[
  {
    "path": ".github/workflows/codespell.yml",
    "content": "---\nname: Codespell\n\non:\n  push:\n    branches: [master]\n  pull_request:\n    branches: [master]\n\npermissions:\n  contents: read\n\njobs:\n  codespell:\n    name: Check for spelling errors\n    runs-on: ubuntu-latest\n\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v3\n      - name: Codespell\n        uses: codespell-project/actions-codespell@v2\n"
  },
  {
    "path": ".github/workflows/run_tests_nnunet.yml",
    "content": "name: Run nnunet tests on all OS\non:\n  push:\n    paths-ignore:\n      - '**.md'\njobs:\n\n  run-tests:\n    strategy:\n      matrix:\n        # os: [ubuntu-latest, windows-latest, macos-latest]  # fails on windows until https://github.com/MIC-DKFZ/nnUNet/issues/2396 is resolved\n        os: [ubuntu-latest, macos-latest]\n        python-version: [\"3.10\"]\n    runs-on: ${{ matrix.os }}\n\n    steps:\n    - uses: actions/checkout@v4\n\n    - name: Set up Python ${{ matrix.python-version }}\n      uses: actions/setup-python@v5\n      with:\n        python-version: ${{ matrix.python-version }}\n\n    - name: Download and extract weights and download example file\n      # use python instead of curl to avoid the need to install curl (more difficult for macos)\n      run: |\n        mkdir -p $HOME/github_actions_nnunet/results\n        python -c \"import urllib.request; urllib.request.urlretrieve('https://github.com/wasserth/TotalSegmentator/releases/download/v2.0.0-weights/Dataset300_body_6mm_1559subj.zip', '$HOME/github_actions_nnunet/results/tmp_download_file.zip')\"\n        unzip -o $HOME/github_actions_nnunet/results/tmp_download_file.zip -d $HOME/github_actions_nnunet/results\n        rm $HOME/github_actions_nnunet/results/tmp_download_file.zip\n\n    - name: Install dependencies on Ubuntu\n      if: runner.os == 'Linux'\n      run: |\n        python -m pip install --upgrade pip\n        pip install pytest Cython\n        pip install torch==2.4.0 -f https://download.pytorch.org/whl/cpu\n        pip install .\n\n    - name: Install dependencies on Windows / MacOS\n      if: runner.os == 'Windows' || runner.os == 'macOS'\n      run: |\n        python -m pip install --upgrade pip\n        pip install pytest Cython\n        pip install torch==2.4.0\n        pip install .\n\n    - name: Run test script\n      run: python nnunetv2/tests/integration_tests/run_nnunet_inference.py\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nenv/\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\n*.egg-info/\n.installed.cfg\n*.egg\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*,cover\n.hypothesis/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# IPython Notebook\n.ipynb_checkpoints\n\n# pyenv\n.python-version\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# dotenv\n.env\n\n# virtualenv\nvenv/\nENV/\n\n# Spyder project settings\n.spyderproject\n\n# Rope project settings\n.ropeproject\n\n*.memmap\n*.png\n*.zip\n*.npz\n*.npy\n*.jpg\n*.jpeg\n.idea\n*.txt\n.idea/*\n*.png\n# *.nii.gz  # nifti files needed for example_data for github actions tests\n*.nii\n*.tif\n*.bmp\n*.pkl\n*.xml\n*.pkl\n*.pdf\n*.png\n*.jpg\n*.jpeg\n\n*.model\n\n!documentation/assets/scribble_example.png\n\nCLAUDE.md\nAGENTS.md"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "# Contributing to nnU-Net\n\nThank you for your interest in contributing to nnU-Net.\n\nnnU-Net is developed and maintained by researchers at DKFZ. There is no dedicated funding or staff for maintaining the \nrepository, and development happens alongside research and teaching responsibilities. Our bandwidth for reviewing \nexternal contributions is therefore limited, and review times may be long.\n\n## General principles\n\nnnU-Net is intentionally designed to be focused, stable, and generally applicable across datasets and use cases. \nContributions should respect this philosophy and should not introduce unnecessary complexity or specialization.\n\nNew functionality must either be generally valid across datasets and setups or convincingly benefit a large enough \nportion of the user base. We aim to avoid bloating the framework or increasing its complexity further.\n\n## How to contribute\n\nFor larger features and refactors, please open a GitHub issue to discuss the idea before starting work. Tag\n@FabianIsensee so that the discussion doesn't get missed.\n\nTo submit a contribution, fork the repository, make your changes on a branch, and open a pull request.\n\n## Bug reports and bug fixes\n\nBug reports must include a minimal reproducible example. Without a repro, it is usually impossible for us to \ninvestigate issues.\n\nPull requests fixing bugs should also include a clear reproduction of the issue and an explanation of how the fix \nresolves it.\n\n## Performance improvements\n\nIf a pull request claims performance improvements, it must include benchmarks demonstrating the effect. The benchmark \nsetup must be described clearly enough for us to reproduce the results independently. 
We may run additional \ntests ourselves before merging.\n\n## Contributions that are unlikely to be merged\n\nTo keep the framework maintainable and the workload manageable on our end, we deprioritize:\n\n- dataset-specific code\n- features that only apply to niche setups\n- narrow custom architectures or training pipelines\n- large refactorings without prior discussion\n- small PRs fixing minor typos or formatting issues\n\n## Final note\n\nWe appreciate the effort people invest in improving nnU-Net!"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [2019] [Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License."
  },
  {
    "path": "documentation/__init__.py",
    "content": ""
  },
  {
    "path": "documentation/benchmarking.md",
    "content": "# nnU-Netv2 benchmarks\n\nDoes your system run like it should? Is your epoch time longer than expected? What epoch times should you expect?\n\nLook no further for we have the solution here!\n\n## What does the nnU-netv2 benchmark do?\n\nnnU-Net's benchmark trains models for 5 epochs. At the end, the fastest epoch will \nbe noted down, along with the GPU name, torch version and cudnn version. You can find the benchmark output in the \ncorresponding nnUNet_results subfolder (see example below). Don't worry, we also provide scripts to collect your \nresults. Or you just start a benchmark and look at the console output. Everything is possible. Nothing is forbidden.\n\nThe benchmark implementation revolves around two trainers:\n- `nnUNetTrainerBenchmark_5epochs` runs a regular training for 5 epochs. When completed, writes a .json file with the fastest \nepoch time as well as the GPU used and the torch and cudnn versions. Useful for speed testing the entire pipeline \n(data loading, augmentation, GPU training)\n- `nnUNetTrainerBenchmark_5epochs_noDataLoading` is the same, but it doesn't do any data loading or augmentation. It \njust presents dummy arrays to the GPU. Useful for checking pure GPU speed.\n\n## How to run the nnU-Netv2 benchmark?\nIt's quite simple, actually. It looks just like a regular nnU-Net training.\n\nWe provide reference numbers for some of the Medical Segmentation Decathlon datasets because they are easily \naccessible: [download here](https://drive.google.com/drive/folders/1HqEgzS8BV2c7xYNrZdEAnrHk7osJJ--2). If it needs to be \nquick and dirty, focus on Tasks 2 and 4. Download and extract the data and convert them to the nnU-Net format with \n`nnUNetv2_convert_MSD_dataset`. \nRun `nnUNetv2_plan_and_preprocess` for them.\n\nThen, for each dataset, run the following commands (only one per GPU! 
Or one after the other):\n\n```bash\nnnUNetv2_train DATASET_ID 2d 0 -tr nnUNetTrainerBenchmark_5epochs\nnnUNetv2_train DATASET_ID 3d_fullres 0 -tr nnUNetTrainerBenchmark_5epochs\nnnUNetv2_train DATASET_ID 2d 0 -tr nnUNetTrainerBenchmark_5epochs_noDataLoading\nnnUNetv2_train DATASET_ID 3d_fullres 0 -tr nnUNetTrainerBenchmark_5epochs_noDataLoading\n```\n\nIf you want to inspect the outcome manually, check (for example!) your \n`nnUNet_results/DATASET_NAME/nnUNetTrainerBenchmark_5epochs__nnUNetPlans__3d_fullres/fold_0/` folder for the `benchmark_result.json` file.\n\nNote that there can be multiple entries in this file if the benchmark was run on different GPU types, torch versions or cudnn versions!\n\nIf you want to summarize your results like we did in our [results](#results), check the \n[summary script](../nnunetv2/batch_running/benchmarking/summarize_benchmark_results.py). Here you need to change the \ntorch version, cudnn version and dataset you want to summarize, then execute the script. You can find the exact \nvalues you need to put there in one of your `benchmark_result.json` files.\n\n## Results\nWe have tested a variety of GPUs and summarized the results in a \n[spreadsheet](https://docs.google.com/spreadsheets/d/12Cvt_gr8XU2qWaE0XJk5jJlxMEESPxyqW0CWbQhTNNY/edit?usp=sharing). \nNote that you can select the torch and cudnn versions at the bottom! There may be comments in this spreadsheet. Read them!\n\n## Result interpretation\n\nResults are shown as epoch time in seconds. Lower is better (duh). Epoch times can fluctuate between runs, so as \nlong as you are within like 5-10% of the numbers we report, everything should be dandy. \n\nIf not, here is how you can try to find the culprit!\n\nThe first thing to do is to compare the performance between the `nnUNetTrainerBenchmark_5epochs_noDataLoading` and \n`nnUNetTrainerBenchmark_5epochs` trainers. 
If the difference is about the same as we report in our spreadsheet, but \nboth your numbers are worse, the problem is with your GPU:\n\n- Are you certain you compare the correct GPU? (duh)\n- If yes, then you might want to install PyTorch in a different way. Never `pip install torch`! Go to the\n[PyTorch installation](https://pytorch.org/get-started/locally/) page, select the most recent cuda version your \nsystem supports and only then copy and execute the correct command! Either pip or conda should work\n- If the problem is still not fixed, we recommend you try \n[compiling pytorch from source](https://github.com/pytorch/pytorch#from-source). It's more difficult but that's \nhow we roll here at the DKFZ (at least the cool kids here).\n- Another thing to consider is to try exactly the same torch + cudnn version as we did in our spreadsheet. \nSometimes newer versions can actually degrade performance and there might be bugs from time to time. Older versions \nare also often a lot slower!\n- Finally, some very basic things that could impact your GPU performance: \n  - Is the GPU cooled adequately? Check the temperature with `nvidia-smi`. Hot GPUs throttle performance in order to not self-destruct\n  - Is your OS using the GPU for displaying your desktop at the same time? If so then you can expect a performance \n  penalty (I dunno like 10% !?). That's expected and OK.\n  - Are other users using the GPU as well?\n\n\nIf you see a large performance difference between `nnUNetTrainerBenchmark_5epochs_noDataLoading` (fast) and \n`nnUNetTrainerBenchmark_5epochs` (slow) then the problem might be related to data loading and augmentation. As a \nreminder, nnU-net does not use pre-augmented images (offline augmentation) but instead generates augmented training \nsamples on the fly during training (no, you cannot switch it to offline). This requires that your system can do partial \nreads of the image files fast enough (SSD storage required!) 
and that your CPU is powerful enough to run the augmentations.\n\nCheck the following:\n\n- [CPU bottleneck] How many CPU threads are running during the training? nnU-Net uses 12 processes for data augmentation by default. \nIf you see those 12 running constantly during training, consider increasing the number of processes used for data \naugmentation (provided there is headroom on your CPU!). Increase the number until you see fewer active workers than \nyou configured (or just set the number to 32 and forget about it). You can do so by setting the `nnUNet_n_proc_DA` \nenvironment variable (Linux: `export nnUNet_n_proc_DA=24`). Read [here](set_environment_variables.md) on how to do this.\nIf your CPU does not support more processes (setting more processes than your CPU has threads makes \nno sense!) you are out of luck and in desperate need of a system upgrade!\n- [I/O bottleneck] If you don't see 12 (or nnUNet_n_proc_DA if you set it) processes running but your training times \nare still slow then open up `top` (sorry, Windows users. I don't know how to do this on Windows) and look at the value \nleft of 'wa' in the row that begins \nwith '%Cpu(s)'. If this is >1.0 (arbitrarily set threshold here, essentially look for unusually high 'wa'. In a \nhealthy training 'wa' will be almost 0) then your storage cannot keep up with data loading. Make sure to set \nnnUNet_preprocessed to a folder that is located on an SSD. nvme is preferred over SATA. PCIe3 is enough. 3000MB/s \nsequential read recommended.\n- [funky stuff] Sometimes there is funky stuff going on, especially when batch sizes are large, files are small and \npatch sizes are small as well. As part of the data loading process, nnU-Net needs to open and close a file for each \ntraining sample. Now imagine a dataset like Dataset004_Hippocampus where for the 2d config we have a batch size of \n366 and we run 250 iterations in <10s on an A100. 
That's a lotta files per second (366 * 250 / 10 = 9150 files per second). \nOof. If the files are on some network drive (even if it's nvme) then (probably) good night. The good news: nnU-Net\nhas got you covered: add `export nnUNet_keep_files_open=True` to your .bashrc and the problem goes away. The neat \npart: it causes new problems if you are not allowed to have enough open files. You may have to increase the number \nof allowed open files. `ulimit -n` gives your current limit (Linux only). It should not be something like 1024. \nIncreasing that to 65535 works well for me. See here for how to change these limits: \n[Link](https://kupczynski.info/posts/ubuntu-18-10-ulimits/) \n(works for Ubuntu 18, google for your OS!).\n\n"
  },
  {
    "path": "documentation/changelog.md",
    "content": "# What is different in v2?\n\n- We now support **hierarchical labels** (named regions in nnU-Net). For example, instead of training BraTS with the \n'edema', 'necrosis' and 'enhancing tumor' labels you can directly train it on the target areas 'whole tumor', \n'tumor core' and 'enhancing tumor'. See [here](region_based_training.md) for a detailed description + also have a look at the \n[BraTS 2021 conversion script](../nnunetv2/dataset_conversion/Dataset137_BraTS21.py).\n- Cross-platform support. Cuda, mps (Apple M1/M2) and of course CPU support! Simply select the device with \n`-device` in `nnUNetv2_train` and `nnUNetv2_predict`.\n- Unified trainer class: nnUNetTrainer. No messing around with cascaded trainer, DDP trainer, region-based trainer, \nignore trainer etc. All default functionality is in there!\n- Supports more input/output data formats through ImageIO classes.\n- I/O formats can be extended by implementing new Adapters based on `BaseReaderWriter`.\n- The nnUNet_raw_cropped folder no longer exists -> saves disk space at no performance penalty. magic! (no jk the \nsaving of cropped npz files was really slow, so it's actually faster to crop on the fly).\n- Preprocessed data and segmentation are stored in different files when unpacked. Seg is stored as int8 and thus \ntakes 1/4 of the disk space per pixel (and I/O throughput) as in v1.\n- Native support for multi-GPU (DDP) TRAINING. \nMulti-GPU INFERENCE should still be run with `CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] -num_parts Y -part_id X`. \nThere is no cross-GPU communication in inference, so it doesn't make sense to add additional complexity with DDP.\n- All nnU-Net functionality is now also accessible via API. 
Check the corresponding entry point in `setup.py` to see \nwhat functions you need to call.\n- Dataset fingerprint is now explicitly created and saved in a json file (see nnUNet_preprocessed).\n\n- Complete overhaul of plans files (read also [this](explanation_plans_files.md)):\n  - Plans are now .json and can be opened and read more easily\n  - Configurations are explicitly named (\"3d_fullres\", ...)\n  - Configurations can inherit from each other to make manual experimentation easier\n  - A ton of additional functionality is now included in and can be changed through the plans, for example normalization strategy, resampling etc.\n  - Stages of the cascade are now explicitly listed in the plans. 3d_lowres has 'next_stage' (which can also be a \n  list of configurations!). 3d_cascade_fullres has a 'previous_stage' entry. By manually editing plans files you can \n  now connect anything you want, for example 2d with 3d_fullres or whatever. Be wild! (But don't create cycles!)\n  - Multiple configurations can point to the same preprocessed data folder to save disk space. Careful! Only \n  configurations that use the same spacing, resampling, normalization etc. should share a data source! 
By default, \n  3d_fullres and 3d_cascade_fullres share the same data\n  - Any number of configurations can be added to the plans (remember to give them a unique \"data_identifier\"!)\n\nFolder structures are different and more user-friendly:\n- nnUNet_preprocessed\n  - By default, preprocessed data is now saved as: `nnUNet_preprocessed/DATASET_NAME/PLANS_IDENTIFIER_CONFIGURATION` to clearly link them to their corresponding plans and configuration \n  - Name of the folder containing the preprocessed images can be adapted with the `data_identifier` key.\n- nnUNet_results\n  - Results are now sorted as follows: DATASET_NAME/TRAINERCLASS__PLANSIDENTIFIER__CONFIGURATION/FOLD\n\n## What other changes are planned and not yet implemented?\n- Integration into MONAI (together with our friends at Nvidia)\n- New pretrained weights for a large number of datasets (coming very soon)\n\n\n[//]: # (- nnU-Net now also natively supports an **ignore label**. Pixels with this label will not contribute to the loss. )\n\n[//]: # (Use this to learn from sparsely annotated data, or excluding irrelevant areas from training. Read more [here]&#40;ignore_label.md&#41;.)"
  },
  {
    "path": "documentation/competitions/AortaSeg24.md",
    "content": "Authors: \\\nMaximilian Rokuss, Michael Baumgartner, Yannick Kirchhoff, Klaus H. Maier-Hein*, Fabian Isensee*\n\n*: equal contribution\n\nAuthor Affiliations:\\\nDivision of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg \\\nHelmholtz Imaging\n\n# Introduction\n\nThis document describes our submission to the [AortaSeg24 Challenge](https://aortaseg24.grand-challenge.org/). \nOur model is essentially a nnU-Net ResEnc L with modified data augmentation. We disable left/right mirroring and use the heavy data augmentation [DA5 Trainer](../../nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDA5.py). Training was performed on an A100 40GB GPU.\n\n# Experiment Planning and Preprocessing\nAfter converting the data into the [nnUNet format](../../../nnUNet/documentation/dataset_format.md) (either keep and just rename the .mha files or convert them to .nii.gz), you can run the preprocessing:\n\n```bash\nnnUNetv2_plan_and_preprocess -d 610 -c 3d_fullres -pl nnUNetPlannerResEncL -np 16\n```\n\n# Training\nWe train our model using:\n\n```bash\nnnUNetv2_train 610 3d_fullres all -p nnUNetResEncUNetLPlans -tr nnUNetTrainer_onlyMirror01_DA5\n```\nModels are trained from scratch. We train one model using all the images and a five fold cross validation ensemble for the submission.\n\nWe recommend to increase the number of processes used for data augmentation. Otherwise you can run into CPU bottlenecks.\nUse `export nnUNet_n_proc_DA=32` or higher (if your system permits!).\n\n# Inference\nFor inference you can use the default [nnUNet inference functionalities](../../../nnUNet/documentation/how_to_use_nnunet.md). 
Specifically, once the training is finished, run:\n\n```bash\nnnUNetv2_predict_from_modelfolder -i INPUT_FOLDER -o OUTPUT_FOLDER -m MODEL_FOLDER -f all\n```\n\nfor the single model trained on all the data and \n\n```bash\nnnUNetv2_predict_from_modelfolder -i INPUT_FOLDER -o OUTPUT_FOLDER -m MODEL_FOLDER\n```\n\nfor the five-fold ensemble."
  },
  {
    "path": "documentation/competitions/AutoPETII.md",
    "content": "# Look Ma, no code: fine tuning nnU-Net for the AutoPET II challenge by only adjusting its JSON plans\n\nPlease cite our paper :-*\n\n```text\nCOMING SOON\n```\n\n## Intro\n\nSee the [Challenge Website](https://autopet-ii.grand-challenge.org/) for details on the challenge.\n\nOur solution to this challenge requires no code changes at all. All we do is optimize nnU-Net's hyperparameters \n(architecture, batch size, patch size) through modifying the nnUNetplans.json file.\n\n## Prerequisites\nUse the latest pytorch version!\n\nWe recommend you use the latest nnU-Net version as well! We ran our trainings with commit 913705f which you can try in case something doesn't work as expected:\n`pip install git+https://github.com/MIC-DKFZ/nnUNet.git@913705f`\n\n## How to reproduce our trainings\n\n### Download and convert the data\n1. Download and extract the AutoPET II dataset\n2. Convert it to nnU-Net format by running `python nnunetv2/dataset_conversion/Dataset221_AutoPETII_2023.py FOLDER` where folder is the extracted AutoPET II dataset.\n\n### Experiment planning and preprocessing\nWe deviate a little from the standard nnU-Net procedure because all our experiments are based on just the 3d_fullres configuration\n\nRun the following commands:\n   - `nnUNetv2_extract_fingerprint -d 221` extracts the dataset fingerprint \n   - `nnUNetv2_plan_experiment -d 221` does the planning for the plain unet\n   - `nnUNetv2_plan_experiment -d 221 -pl ResEncUNetPlanner` does the planning for the residual encoder unet\n   - `nnUNetv2_preprocess -d 221 -c 3d_fullres` runs all the preprocessing we need\n\n### Modification of plans files\nPlease read the [information on how to modify plans files](../explanation_plans_files.md) first!!!\n\n\nIt is easier to have everything in one plans file, so the first thing we do is transfer the ResEnc UNet to the \ndefault plans file. 
We use the configuration inheritance feature of nnU-Net to make it use the same data as the \n3d_fullres configuration.\nAdd the following to the 'configurations' dict in 'nnUNetPlans.json':\n\n```json\n        \"3d_fullres_resenc\": {\n            \"inherits_from\": \"3d_fullres\",\n            \"network_arch_class_name\": \"ResidualEncoderUNet\",\n            \"n_conv_per_stage_encoder\": [\n                1,\n                3,\n                4,\n                6,\n                6,\n                6\n            ],\n            \"n_conv_per_stage_decoder\": [\n                1,\n                1,\n                1,\n                1,\n                1\n            ]\n        },\n```\n\n(these values are basically just copied from the 'nnUNetResEncUNetPlans.json' file! With everything redundant being omitted thanks to inheritance from 3d_fullres)\n\nNow we crank up the patch and batch sizes. Add the following configurations:\n```json\n        \"3d_fullres_resenc_bs80\": {\n            \"inherits_from\": \"3d_fullres_resenc\",\n            \"batch_size\": 80\n            },\n        \"3d_fullres_resenc_192x192x192_b24\": {\n            \"inherits_from\": \"3d_fullres_resenc\",\n            \"patch_size\": [\n                192,\n                192,\n                192\n            ],\n            \"batch_size\": 24\n        }\n```\n\nSave the file (and check for potential Syntax Errors!)\n\n### Run trainings\nTraining each model requires 8 Nvidia A100 40GB GPUs. Expect training to run for 5-7 days. You'll need a really good \nCPU to handle the data augmentation! 128C/256T are a must! 
If you have fewer threads available, scale down nnUNet_n_proc_DA accordingly.\n\n```bash\nnnUNet_compile=T nnUNet_n_proc_DA=28 nnUNetv2_train 221 3d_fullres_resenc_bs80 0 -num_gpus 8\nnnUNet_compile=T nnUNet_n_proc_DA=28 nnUNetv2_train 221 3d_fullres_resenc_bs80 1 -num_gpus 8\nnnUNet_compile=T nnUNet_n_proc_DA=28 nnUNetv2_train 221 3d_fullres_resenc_bs80 2 -num_gpus 8\nnnUNet_compile=T nnUNet_n_proc_DA=28 nnUNetv2_train 221 3d_fullres_resenc_bs80 3 -num_gpus 8\nnnUNet_compile=T nnUNet_n_proc_DA=28 nnUNetv2_train 221 3d_fullres_resenc_bs80 4 -num_gpus 8\n\nnnUNet_compile=T nnUNet_n_proc_DA=28 nnUNetv2_train 221 3d_fullres_resenc_192x192x192_b24 0 -num_gpus 8\nnnUNet_compile=T nnUNet_n_proc_DA=28 nnUNetv2_train 221 3d_fullres_resenc_192x192x192_b24 1 -num_gpus 8\nnnUNet_compile=T nnUNet_n_proc_DA=28 nnUNetv2_train 221 3d_fullres_resenc_192x192x192_b24 2 -num_gpus 8\nnnUNet_compile=T nnUNet_n_proc_DA=28 nnUNetv2_train 221 3d_fullres_resenc_192x192x192_b24 3 -num_gpus 8\nnnUNet_compile=T nnUNet_n_proc_DA=28 nnUNetv2_train 221 3d_fullres_resenc_192x192x192_b24 4 -num_gpus 8\n```\n\nDone!\n\n(We also provide pretrained weights in case you don't want to invest the GPU resources, see below)\n\n## How to make predictions with pretrained weights\nOur final model is an ensemble of two configurations:\n- ResEnc UNet with batch size 80\n- ResEnc UNet with patch size 192x192x192 and batch size 24\n\nTo run inference with these models, do the following:\n\n1. Download the pretrained model weights from [Zenodo](https://zenodo.org/record/8362371)\n2. Install both .zip files using `nnUNetv2_install_pretrained_model_from_zip`\n3. Make sure \n4. 
Now you can run inference on new cases with `nnUNetv2_predict`:\n   - `nnUNetv2_predict -i INPUT -o OUTPUT1 -d 221 -c 3d_fullres_resenc_bs80 -f 0 1 2 3 4 -step_size 0.6 --save_probabilities`   \n   - `nnUNetv2_predict -i INPUT -o OUTPUT2 -d 221 -c 3d_fullres_resenc_192x192x192_b24 -f 0 1 2 3 4 --save_probabilities`\n   - `nnUNetv2_ensemble -i OUTPUT1 OUTPUT2 -o OUTPUT_ENSEMBLE`\n\nNote that our inference Docker omitted TTA via mirroring along the axial direction during prediction (only sagittal + \ncoronal mirroring). This was\ndone to keep the inference time below 10 minutes per image on a T4 GPU (we actually never tested whether we could \nhave left this enabled). Just leave it on! You can also leave the step_size at default for the 3d_fullres_resenc_bs80."
  },
  {
    "path": "documentation/competitions/FLARE24/Task_1/__init__.py",
    "content": ""
  },
  {
    "path": "documentation/competitions/FLARE24/Task_1/inference_flare_task1.py",
    "content": "from typing import Union, Tuple\nimport argparse\nimport numpy as np\nimport os\nfrom os.path import join\nfrom pathlib import Path\nfrom time import time\nimport torch\nfrom torch._dynamo import OptimizedModule\n\nfrom nnunetv2.utilities.label_handling.label_handling import LabelManager\n\nfrom acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice\nfrom batchgenerators.utilities.file_and_folder_operations import load_json\n\nimport nnunetv2\nfrom nnunetv2.configuration import default_num_processes\nfrom nnunetv2.utilities.find_class_by_name import recursive_find_python_class\nfrom nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels\nfrom nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager\nfrom nnunetv2.inference.predict_from_raw_data import nnUNetPredictor\nfrom nnunetv2.imageio.nibabel_reader_writer import NibabelIOWithReorient\n\n\nclass FlarePredictor(nnUNetPredictor):\n    def initialize_from_trained_model_folder(self, model_training_output_dir: str,\n                                             use_folds: Union[Tuple[Union[int, str]], None],\n                                             checkpoint_name: str = 'checkpoint_final.pth'):\n        \"\"\"\n        This is used when making predictions with a trained model\n        \"\"\"\n        if use_folds is None:\n            use_folds = nnUNetPredictor.auto_detect_available_folds(model_training_output_dir, checkpoint_name)\n\n        dataset_json = load_json(join(model_training_output_dir, 'dataset.json'))\n        plans = load_json(join(model_training_output_dir, 'plans.json'))\n        plans_manager = PlansManager(plans)\n\n        if isinstance(use_folds, str):\n            use_folds = [use_folds]\n\n        parameters = []\n        for i, f in enumerate(use_folds):\n            f = int(f) if f != 'all' else f\n            checkpoint = torch.load(join(model_training_output_dir, f'fold_{f}', 
checkpoint_name),\n                                    map_location=torch.device('cpu'))\n            if i == 0:\n                trainer_name = checkpoint['trainer_name']\n                configuration_name = checkpoint['init_args']['configuration']\n                inference_allowed_mirroring_axes = checkpoint['inference_allowed_mirroring_axes'] if \\\n                    'inference_allowed_mirroring_axes' in checkpoint.keys() else None\n\n            parameters.append(checkpoint['network_weights'])\n\n        configuration_manager = plans_manager.get_configuration(configuration_name)\n        # restore network\n        num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)\n        trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], \"training\", \"nnUNetTrainer\"),\n                                                    trainer_name, 'nnunetv2.training.nnUNetTrainer')\n        if trainer_class is None:\n            raise RuntimeError(f'Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. 
'\n                               f'Please place it there (in any .py file)!')\n        network = trainer_class.build_network_architecture(\n            configuration_manager.network_arch_class_name,\n            configuration_manager.network_arch_init_kwargs,\n            configuration_manager.network_arch_init_kwargs_req_import,\n            num_input_channels,\n            plans_manager.get_label_manager(dataset_json).num_segmentation_heads,\n            enable_deep_supervision=False\n        )\n\n        self.plans_manager = plans_manager\n        self.configuration_manager = configuration_manager\n        self.list_of_parameters = parameters\n        self.network = network\n        self.dataset_json = dataset_json\n        self.trainer_name = trainer_name\n        self.allowed_mirroring_axes = inference_allowed_mirroring_axes\n        self.label_manager = plans_manager.get_label_manager(dataset_json)\n\n        if ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) \\\n                and not isinstance(self.network, OptimizedModule):\n            self.network = torch.compile(self.network)\n\n\ndef convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits: Union[torch.Tensor, np.ndarray],\n                                                                plans_manager: PlansManager,\n                                                                configuration_manager: ConfigurationManager,\n                                                                label_manager: LabelManager,\n                                                                properties_dict: dict,\n                                                                return_probabilities: bool = False,\n                                                                num_threads_torch: int = default_num_processes):\n    old_threads = torch.get_num_threads()\n    torch.set_num_threads(num_threads_torch)\n\n    # resample to 
original shape\n    current_spacing = configuration_manager.spacing if \\\n        len(configuration_manager.spacing) == \\\n        len(properties_dict['shape_after_cropping_and_before_resampling']) else \\\n        [properties_dict['spacing'][0], *configuration_manager.spacing]\n    predicted_logits = configuration_manager.resampling_fn_probabilities(predicted_logits,\n                                            properties_dict['shape_after_cropping_and_before_resampling'],\n                                            current_spacing,\n                                            properties_dict['spacing'])\n    # return value of resampling_fn_probabilities can be ndarray or Tensor but that does not matter because\n    # apply_inference_nonlin will convert to torch\n    # predicted_probabilities = label_manager.apply_inference_nonlin(predicted_logits)\n    segmentation = predicted_logits.argmax(0)\n    del predicted_logits\n    # segmentation = label_manager.convert_probabilities_to_segmentation(predicted_probabilities)\n\n    # segmentation may be torch.Tensor but we continue with numpy\n    if isinstance(segmentation, torch.Tensor):\n        segmentation = segmentation.cpu().numpy()\n\n    # put segmentation in bbox (revert cropping)\n    segmentation_reverted_cropping = np.zeros(properties_dict['shape_before_cropping'],\n                                              dtype=np.uint8 if len(label_manager.foreground_labels) < 255 else np.uint16)\n    slicer = bounding_box_to_slice(properties_dict['bbox_used_for_cropping'])\n    segmentation_reverted_cropping[slicer] = segmentation\n    del segmentation\n\n    # revert transpose\n    segmentation_reverted_cropping = segmentation_reverted_cropping.transpose(plans_manager.transpose_backward)\n    torch.set_num_threads(old_threads)\n    return segmentation_reverted_cropping\n\n\ndef export_prediction_from_logits(predicted_array_or_file: Union[np.ndarray, torch.Tensor], properties_dict: dict,\n                           
       configuration_manager: ConfigurationManager,\n                                  plans_manager: PlansManager,\n                                  dataset_json_dict_or_file: Union[dict, str], output_file_truncated: str,\n                                  save_probabilities: bool = False):\n\n    if isinstance(dataset_json_dict_or_file, str):\n        dataset_json_dict_or_file = load_json(dataset_json_dict_or_file)\n\n    label_manager = plans_manager.get_label_manager(dataset_json_dict_or_file)\n    ret = convert_predicted_logits_to_segmentation_with_correct_shape(\n        predicted_array_or_file, plans_manager, configuration_manager, label_manager, properties_dict,\n        return_probabilities=save_probabilities\n    )\n    del predicted_array_or_file\n\n    segmentation_final = ret\n\n    rw = NibabelIOWithReorient()\n    rw.write_seg(segmentation_final, output_file_truncated + dataset_json_dict_or_file['file_ending'],\n                 properties_dict)\n\n\ndef predict_flare(input_dir, output_dir, model_folder, folds=(\"all\",)):\n    input_dir = Path(input_dir)\n    output_dir = Path(output_dir)\n    input_files = sorted(input_dir.glob(\"*.nii.gz\"))\n    output_files = [str(output_dir / f.name[:-12]) for f in input_files]\n    for input_file, output_file in zip(input_files, output_files):\n        print(f\"Predicting {input_file.name}\")\n        start = time()\n        plans_manager = PlansManager(load_json(join(model_folder, 'plans.json')))\n        configuration_manager = plans_manager.get_configuration(\"3d_fullres\")\n        dataset_json = load_json(join(model_folder, 'dataset.json'))\n        rw = NibabelIOWithReorient()\n        image, props = rw.read_images([input_file,])\n        with torch.no_grad():\n            predictor = FlarePredictor(tile_step_size=0.5, use_mirroring=False)\n            predictor.initialize_from_trained_model_folder(model_folder, use_folds=folds)\n            preprocessor = 
configuration_manager.preprocessor_class(verbose=False)\n            data, _ = preprocessor.run_case_npy(image,\n                                                None,\n                                                props,\n                                                plans_manager,\n                                                configuration_manager,\n                                                dataset_json)\n            data = torch.from_numpy(data).to(dtype=torch.float32, memory_format=torch.contiguous_format)\n            predicted_logits = predictor.predict_logits_from_preprocessed_data(data).cpu()\n            export_prediction_from_logits(predicted_logits, props, configuration_manager,\n                                            plans_manager, dataset_json, output_file,\n                                            False)\n        print(f\"Prediction time: {time() - start:.2f}s\")\n\n\nif __name__ == '__main__':\n    os.environ['nnUNet_compile'] = 'f'\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"-i\", \"--input\", default=\"/workspace/inputs\")\n    parser.add_argument(\"-o\", \"--output\", default=\"/workspace/outputs\")\n    parser.add_argument(\"-m\", \"--model\", default=\"/opt/app/_trained_model\")\n    parser.add_argument(\"-f\", \"--folds\", nargs=\"+\", default=[\"all\"])\n    args = parser.parse_args()\n    predict_flare(args.input, args.output, args.model, args.folds)"
  },
  {
    "path": "documentation/competitions/FLARE24/Task_1/readme.md",
    "content": "Authors: \\\nYannick Kirchhoff, Maximilian Rouven Rokuss, Benjamin Hamm, Ashis Ravindran, Constantin Ulrich, Klaus Maier-Hein<sup>&#8224;</sup>, Fabian Isensee<sup>&#8224;</sup>\n\n&#8224;: equal contribution\n\n# Introduction\n\nThis document describes our contribution to [Task 1 of the FLARE24 Challenge](https://www.codabench.org/competitions/2319/).\nOur model is basically a default nnU-Net trained with a larger batch size of 4 and 8, respectively. We submitted the batch size 8 model and an ensemble of the batch size 4 and batch size 8 models to the final test set.\n\n# Experiment Planning and Preprocessing\n\nBring the downloaded data into the [nnU-Net format](../../../nnUNet/documentation/dataset_format.md) and add the dataset.json file as given here:\n\n```json\n{\n    \"name\": \"Dataset301_FLARE24Task1_labeled\",\n    \"description\": \"Pan Cancer Segmentation\",\n    \"labels\": {\n        \"background\": 0,\n        \"lesion\": 1\n    },\n    \"file_ending\": \".nii.gz\",\n    \"channel_names\": {\n        \"0\": \"CT\"\n    },\n    \"numTraining\": 5000\n}\n```\n\nAfterwards you can run the default nnU-Net planning and preprocessing\n\n```bash\nnnUNetv2_plan_and_preprocess -d 301 -c 3d_fullres\n```\n\n## Edit the plans files\n\nIn the generated `nnUNetPlans.json` file add the following configurations\n\n```json\n        \"3d_fullres_bs4\": {\n            \"inherits_from\": \"3d_fullres\",\n            \"batch_size\": 4\n        },\n        \"3d_fullres_bs8\": {\n            \"inherits_from\": \"3d_fullres\",\n            \"batch_size\": 8\n        },\n        \"3d_fullres_bs4u8\": {\n            \"inherits_from\": \"3d_fullres\",\n            \"batch_size\": 48\n        }\n```\n\nNote, the last one is only used for the ensemble model during inference!\n\n# Model training\n\nRun the following commands to train the models with batch size 4 and 8. 
The large batch size helps stabilize the training despite the partial labels present in the dataset and helps handle the large number of scans in the dataset. We therefore keep the number of epochs at 1000.\n\n```bash\nnnUNetv2_train 301 3d_fullres_bs4 all\n\nnnUNetv2_train 301 3d_fullres_bs8 all\n```\n\n# Inference\n\nOur inference is optimized for efficient single scan prediction. For best performance, we strongly recommend running inference using the default `nnUNetv2_predict` command!\n\nIn order to run inference with the ensemble model you need to create a folder called `nnUNetTrainer__nnUNetPlans__3d_fullres_bs4u8` in the results folder and copy the `dataset.json`, `dataset_fingerprint.json` and `plans.json` from one of the other results folders as well as the `fold_all` from both trainings as `fold_0` and `fold_1`, respectively, into this new folder. This allows for easy ensembling of both models.\n\nTo run inference simply run the following commands with `folds` set to `all` for single model inference or `0 1` for the ensemble. `model_folder` is the folder containing the training results, for example `nnUNetTrainer__nnUNetPlans__3d_fullres_bs8`.\n\n```bash\npython inference_flare_task1.py -i input_folder -o output_folder -m model_folder -f folds\n```"
  },
  {
    "path": "documentation/competitions/FLARE24/Task_2/__init__.py",
    "content": ""
  },
  {
    "path": "documentation/competitions/FLARE24/Task_2/inference_flare_task2.py",
    "content": "from typing import Union, List, Tuple\nimport argparse\nimport itertools\nimport multiprocessing\nimport numpy as np\nimport os\nfrom os.path import join\nfrom pathlib import Path\nfrom time import time\nimport torch\nfrom torch._dynamo import OptimizedModule\nfrom tqdm import tqdm\n\nimport openvino as ov\nimport openvino.properties.hint as hints\n\nfrom acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice\nfrom acvl_utils.cropping_and_padding.padding import pad_nd_image\nfrom batchgenerators.utilities.file_and_folder_operations import load_json\n\nfrom nnunetv2.utilities.label_handling.label_handling import LabelManager\nimport nnunetv2\nfrom nnunetv2.configuration import default_num_processes\nfrom nnunetv2.inference.data_iterators import PreprocessAdapterFromNpy\nfrom nnunetv2.inference.sliding_window_prediction import compute_gaussian\nfrom nnunetv2.utilities.find_class_by_name import recursive_find_python_class\nfrom nnunetv2.utilities.helpers import empty_cache, dummy_context\nfrom nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels\nfrom nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager\nfrom nnunetv2.inference.predict_from_raw_data import nnUNetPredictor\n\n\ntorch.set_num_threads(multiprocessing.cpu_count())\n\n\nclass FlarePredictor(nnUNetPredictor):\n    def __init__(self,\n                 tile_step_size: float = 0.5,\n                 use_gaussian: bool = True,\n                 use_mirroring: bool = True,\n                 perform_everything_on_device: bool = False,\n                 device: torch.device = torch.device('cpu'),\n                 verbose: bool = False,\n                 verbose_preprocessing: bool = False,\n                 allow_tqdm: bool = True,\n                 ):\n        super().__init__(tile_step_size, use_gaussian, use_mirroring, perform_everything_on_device, device, verbose,\n                         
verbose_preprocessing, allow_tqdm)\n        if self.device == torch.device('cuda') or self.device == 'cuda':\n            raise RuntimeError('CUDA is not supported for this task')\n\n    def initialize_from_trained_model_folder(self, model_training_output_dir: str,\n                                             use_folds: Union[Tuple[Union[int, str]], None],\n                                             checkpoint_name: str = 'checkpoint_final.pth',\n                                             save_model: bool = True):\n        \"\"\"\n        This is used when making predictions with a trained model\n        \"\"\"\n        if use_folds is None:\n            use_folds = nnUNetPredictor.auto_detect_available_folds(model_training_output_dir, checkpoint_name)\n\n        dataset_json = load_json(join(model_training_output_dir, 'dataset.json'))\n        plans = load_json(join(model_training_output_dir, 'plans.json'))\n        plans_manager = PlansManager(plans)\n\n        if isinstance(use_folds, str):\n            use_folds = [use_folds]\n        assert len(use_folds) == 1, 'Only one fold is supported for this task'\n\n        parameters = []\n        for i, f in enumerate(use_folds):\n            f = int(f) if f != 'all' else f\n            checkpoint = torch.load(join(model_training_output_dir, f'fold_{f}', checkpoint_name),\n                                    map_location=torch.device('cpu'))\n            if i == 0:\n                trainer_name = checkpoint['trainer_name']\n                configuration_name = checkpoint['init_args']['configuration']\n                inference_allowed_mirroring_axes = checkpoint['inference_allowed_mirroring_axes'] if \\\n                    'inference_allowed_mirroring_axes' in checkpoint.keys() else None\n        \n            if save_model:\n                parameters.append(checkpoint['network_weights'])\n\n        configuration_manager = plans_manager.get_configuration(configuration_name)\n        if save_model:\n            
num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)\n            trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], \"training\", \"nnUNetTrainer\"),\n                                                        trainer_name, 'nnunetv2.training.nnUNetTrainer')\n            if trainer_class is None:\n                raise RuntimeError(f'Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. '\n                                f'Please place it there (in any .py file)!')\n            network = trainer_class.build_network_architecture(\n                configuration_manager.network_arch_class_name,\n                configuration_manager.network_arch_init_kwargs,\n                configuration_manager.network_arch_init_kwargs_req_import,\n                num_input_channels,\n                plans_manager.get_label_manager(dataset_json).num_segmentation_heads,\n                enable_deep_supervision=False\n            )\n            self.network = network\n            self.allowed_mirroring_axes = inference_allowed_mirroring_axes\n\n        self.plans_manager = plans_manager\n        self.configuration_manager = configuration_manager\n        self.dataset_json = dataset_json\n        self.label_manager = plans_manager.get_label_manager(dataset_json)\n\n        if save_model:\n            if not isinstance(self.network, OptimizedModule):\n                self.network.load_state_dict(parameters[0])\n            else:\n                self.network._orig_mod.load_state_dict(parameters[0])\n            self.network.eval()\n\n        config = {hints.performance_mode: hints.PerformanceMode.LATENCY,\n                 hints.enable_cpu_pinning(): True,\n                 }\n        core = ov.Core()\n        core.set_property(\n            \"CPU\",\n            {hints.execution_mode: hints.ExecutionMode.PERFORMANCE},\n        )\n        if save_model:\n            input_tensor = 
torch.randn(1, num_input_channels, *configuration_manager.patch_size, requires_grad=False)\n            ov_model = ov.convert_model(self.network, example_input=input_tensor)\n            ov.save_model(ov_model, f\"{model_training_output_dir}/model.xml\")\n            import sys\n            sys.exit(0)\n        ov_model = core.read_model(f\"{model_training_output_dir}/model.xml\")\n        self.network = core.compile_model(ov_model, \"CPU\", config= config)\n\n    def predict_from_files(self,\n                           list_of_lists_or_source_folder: Union[str, List[List[str]]],\n                           output_folder_or_list_of_truncated_output_files: Union[str, None, List[str]],\n                           save_probabilities: bool = False,\n                           overwrite: bool = True,\n                           num_processes_preprocessing: int = 1,\n                           num_processes_segmentation_export: int = 1,\n                           folder_with_segs_from_prev_stage: str = None,\n                           num_parts: int = 1,\n                           part_id: int = 0):\n        \"\"\"\n        This is nnU-Net's default function for making predictions. 
It works best for batch predictions\n        (predicting many images at once).\n        \"\"\"\n\n        list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files = \\\n            self._manage_input_and_output_lists(list_of_lists_or_source_folder,\n                                                output_folder_or_list_of_truncated_output_files,\n                                                folder_with_segs_from_prev_stage, overwrite, part_id, num_parts,\n                                                save_probabilities)\n        if len(list_of_lists_or_source_folder) == 0:\n            return\n\n        data_iterator = self._internal_get_data_iterator_from_lists_of_filenames(list_of_lists_or_source_folder,\n                                                                                 seg_from_prev_stage_files,\n                                                                                 output_filename_truncated,\n                                                                                 num_processes_preprocessing)\n        \n        return self.predict_from_data_iterator(data_iterator, save_probabilities, num_processes_segmentation_export)\n\n    def _internal_maybe_mirror_and_predict(self, x: torch.Tensor) -> torch.Tensor:\n        mirror_axes = self.allowed_mirroring_axes if self.use_mirroring else None\n        if self.use_openvino:\n            prediction = torch.from_numpy(self.network(x)[0])\n        else:\n            prediction = self.network(x)\n\n        if mirror_axes is not None:\n            assert max(mirror_axes) <= x.ndim - 3, 'mirror_axes does not match the dimension of the input!'\n\n            mirror_axes = [m + 2 for m in mirror_axes]\n            axes_combinations = [\n                c for i in range(len(mirror_axes)) for c in itertools.combinations(mirror_axes, i + 1)\n            ]\n            for axes in axes_combinations:\n                if not self.is_openvino:\n                    
prediction += torch.flip(self.network(torch.flip(x, axes)), axes)\n                else:\n                    temp_pred = torch.from_numpy(self.network(torch.flip(x, axes))[0])\n                    prediction += torch.flip(temp_pred, axes)\n\n            prediction /= (len(axes_combinations) + 1)\n        return prediction\n\n    def _internal_predict_sliding_window_return_logits(self,\n                                                       data: torch.Tensor,\n                                                       slicers,\n                                                       do_on_device: bool = True,\n                                                       ):\n        predicted_logits = n_predictions = prediction = gaussian = workon = None\n        results_device = self.device if do_on_device else torch.device('cpu')\n\n        try:\n            empty_cache(self.device)\n\n            # move data to device\n            if self.verbose:\n                print(f'move image to device {results_device}')\n            data = data.to(results_device)\n\n            # preallocate arrays\n            if self.verbose:\n                print(f'preallocating results arrays on device {results_device}')\n            predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]),\n                                        dtype=torch.half,\n                                        device=results_device)\n            n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device)\n\n            if self.use_gaussian:\n                gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. 
/ 8,\n                                            value_scaling_factor=10,\n                                            device=results_device)\n            else:\n                gaussian = 1\n\n            if not self.allow_tqdm and self.verbose:\n                print(f'running prediction: {len(slicers)} steps')\n            for sl in tqdm(slicers, disable=not self.allow_tqdm):\n                workon = data[sl][None]\n                workon = workon.to(self.device)\n\n                prediction = self._internal_maybe_mirror_and_predict(workon)[0].to(results_device)\n\n                if self.use_gaussian:\n                    prediction *= gaussian\n                predicted_logits[sl] += prediction\n                n_predictions[sl[1:]] += gaussian\n\n            predicted_logits /= n_predictions\n            # check for infs\n            if torch.any(torch.isinf(predicted_logits)):\n                raise RuntimeError('Encountered inf in predicted array. Aborting... If this problem persists, '\n                                'reduce value_scaling_factor in compute_gaussian or increase the dtype of '\n                                'predicted_logits to fp32')\n        except Exception as e:\n            del predicted_logits, n_predictions, prediction, gaussian, workon\n            empty_cache(self.device)\n            empty_cache(results_device)\n            raise e\n        return predicted_logits\n\n    def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        IMPORTANT! IF YOU ARE RUNNING THE CASCADE, THE SEGMENTATION FROM THE PREVIOUS STAGE MUST ALREADY BE STACKED ON\n        TOP OF THE IMAGE AS ONE-HOT REPRESENTATION! SEE PreprocessAdapter ON HOW THIS SHOULD BE DONE!\n\n        RETURNED LOGITS HAVE THE SHAPE OF THE INPUT. 
THEY MUST BE CONVERTED BACK TO THE ORIGINAL IMAGE SIZE.\n        SEE convert_predicted_logits_to_segmentation_with_correct_shape\n        \"\"\"\n        prediction = None\n\n        if not self.use_openvino:\n            for params in self.list_of_parameters:\n\n                # messing with state dict names...\n                if not isinstance(self.network, OptimizedModule):\n                    self.network.load_state_dict(params)\n                else:\n                    self.network._orig_mod.load_state_dict(params)\n\n                # why not leave prediction on device if perform_everything_on_device? Because this may cause the\n                # second iteration to crash due to OOM. Grabbing that with try except cause way more bloated code than\n                # this actually saves computation time\n                if prediction is None:\n                    prediction = self.predict_sliding_window_return_logits(data).to('cpu')\n                else:\n                    prediction += self.predict_sliding_window_return_logits(data).to('cpu')\n\n            if len(self.list_of_parameters) > 1:\n                prediction /= len(self.list_of_parameters)\n\n        else:\n            if prediction is None:\n                prediction = self.predict_sliding_window_return_logits(data)\n            else:\n                prediction += self.predict_sliding_window_return_logits(data)\n\n        if self.verbose: print('Prediction done')\n        return prediction\n\n    @torch.inference_mode()\n    def predict_sliding_window_return_logits(self, input_image: torch.Tensor) \\\n            -> Union[np.ndarray, torch.Tensor]:\n        assert isinstance(input_image, torch.Tensor)\n        if self.device not in [torch.device('cpu'), 'cpu']:\n            self.network = self.network.to(self.device)\n            self.network.eval()\n\n        empty_cache(self.device)\n\n        # Autocast can be annoying\n        # If the device_type is 'cpu' then it's slow as heck on 
some CPUs (no auto bfloat16 support detection)\n        # and needs to be disabled.\n        # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False\n        # is set. Whyyyyyyy. (this is why we don't make use of enabled=False)\n        # So autocast will only be active if we have a cuda device.\n        with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():\n            assert input_image.ndim == 4, 'input_image must be a 4D np.ndarray or torch.Tensor (c, x, y, z)'\n\n            if self.verbose: \n                print(f'Input shape: {input_image.shape}')\n                print(\"step_size:\", self.tile_step_size)\n                print(\"mirror_axes:\", self.allowed_mirroring_axes if self.use_mirroring else None)\n\n            # if input_image is smaller than tile_size we need to pad it to tile_size.\n            data, slicer_revert_padding = pad_nd_image(input_image, self.configuration_manager.patch_size,\n                                                       'constant', {'value': 0}, True,\n                                                       None)\n\n            slicers = self._internal_get_sliding_window_slicers(data.shape[1:])\n\n            if self.perform_everything_on_device and self.device != 'cpu':\n                # we need to try except here because we can run OOM in which case we need to fall back to CPU as a results device\n                try:\n                    predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers,\n                                                                                           self.perform_everything_on_device)\n                except RuntimeError:\n                    print(\n                        'Prediction on device was unsuccessful, probably due to a lack of memory. 
Moving results arrays to CPU')\n                    empty_cache(self.device)\n                    predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, False)\n            else:\n                predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers,\n                                                                                       self.perform_everything_on_device)\n\n            empty_cache(self.device)\n            # revert padding\n            predicted_logits = predicted_logits[(slice(None), *slicer_revert_padding[1:])]\n        return predicted_logits\n\n    def convert_predicted_logits_to_segmentation_with_correct_shape(self, predicted_logits: Union[torch.Tensor, np.ndarray],\n                                                                    plans_manager: PlansManager,\n                                                                    configuration_manager: ConfigurationManager,\n                                                                    label_manager: LabelManager,\n                                                                    properties_dict: dict,\n                                                                    return_probabilities: bool = False,\n                                                                    num_threads_torch: int = default_num_processes):\n\n        # resample to original shape\n        current_spacing = configuration_manager.spacing if \\\n            len(configuration_manager.spacing) == \\\n            len(properties_dict['shape_after_cropping_and_before_resampling']) else \\\n            [properties_dict['spacing'][0], *configuration_manager.spacing]\n        predicted_logits = configuration_manager.resampling_fn_probabilities(predicted_logits,\n                                                properties_dict['shape_after_cropping_and_before_resampling'],\n                                                current_spacing,\n                   
                             properties_dict['spacing'])\n        segmentation = predicted_logits.argmax(0)\n        del predicted_logits\n\n        # segmentation may be torch.Tensor but we continue with numpy\n        if isinstance(segmentation, torch.Tensor):\n            segmentation = segmentation.cpu().numpy()\n\n        # put segmentation in bbox (revert cropping)\n        segmentation_reverted_cropping = np.zeros(properties_dict['shape_before_cropping'],\n                                                  dtype=np.uint8 if len(label_manager.foreground_labels) < 255 else np.uint16)\n        slicer = bounding_box_to_slice(properties_dict['bbox_used_for_cropping'])\n        segmentation_reverted_cropping[slicer] = segmentation\n        del segmentation\n\n        # revert transpose\n        segmentation_reverted_cropping = segmentation_reverted_cropping.transpose(plans_manager.transpose_backward)\n        return segmentation_reverted_cropping\n\n    def export_prediction_from_logits(self, predicted_array_or_file: Union[np.ndarray, torch.Tensor], properties_dict: dict,\n                                      configuration_manager: ConfigurationManager,\n                                      plans_manager: PlansManager,\n                                      dataset_json_dict_or_file: Union[dict, str], output_file_truncated: str,\n                                      save_probabilities: bool = False):\n\n        if isinstance(dataset_json_dict_or_file, str):\n            dataset_json_dict_or_file = load_json(dataset_json_dict_or_file)\n\n        label_manager = plans_manager.get_label_manager(dataset_json_dict_or_file)\n        ret = self.convert_predicted_logits_to_segmentation_with_correct_shape(\n            predicted_array_or_file, plans_manager, configuration_manager, label_manager, properties_dict,\n            return_probabilities=save_probabilities\n        )\n        del predicted_array_or_file\n\n        segmentation_final = ret\n\n        rw = 
plans_manager.image_reader_writer_class()\n        rw.write_seg(segmentation_final, output_file_truncated + dataset_json_dict_or_file['file_ending'],\n                     properties_dict)\n\n    def predict_single_npy_array(self, input_image: np.ndarray, image_properties: dict,\n                                 segmentation_previous_stage: np.ndarray = None,\n                                 output_file_truncated: str = None,\n                                 save_or_return_probabilities: bool = False):\n        \"\"\"\n        WARNING: SLOW. ONLY USE THIS IF YOU CANNOT GIVE NNUNET MULTIPLE IMAGES AT ONCE FOR SOME REASON.\n\n\n        input_image: Make sure to load the image in the way nnU-Net expects! nnU-Net is trained on a certain axis\n                     ordering which cannot be disturbed in inference,\n                     otherwise you will get bad results. The easiest way to achieve that is to use the same I/O class\n                     for loading images as was used during nnU-Net preprocessing! You can find that class in your\n                     plans.json file under the key \"image_reader_writer\". If you decide to freestyle, know that the\n                     default axis ordering for medical images is the one from SimpleITK. 
If you load with nibabel,\n                     you need to transpose your axes AND your spacing from [x,y,z] to [z,y,x]!\n        image_properties must only have a 'spacing' key!\n        \"\"\"\n        ppa = PreprocessAdapterFromNpy([input_image], [segmentation_previous_stage], [image_properties],\n                                       [output_file_truncated],\n                                       self.plans_manager, self.dataset_json, self.configuration_manager,\n                                       num_threads_in_multithreaded=1, verbose=self.verbose)\n        if self.verbose:\n            print('preprocessing')\n        dct = next(ppa)\n\n        if self.verbose:\n            print('predicting')\n        predicted_logits = self.predict_logits_from_preprocessed_data(dct['data']).cpu()\n\n        if self.verbose:\n            print('resampling to original shape')\n        self.export_prediction_from_logits(predicted_logits, dct['data_properties'], self.configuration_manager,\n                                        self.plans_manager, self.dataset_json, output_file_truncated,\n                                        save_or_return_probabilities)\n\n\ndef predict_flare(input_dir, output_dir, model_folder, save_model):\n    input_dir = Path(input_dir)\n    output_dir = Path(output_dir)\n    output_dir.mkdir(exist_ok=True, parents=True)\n    input_files = list(input_dir.glob(\"*.nii.gz\"))\n    output_files = [str(output_dir / f.name[:-12]) for f in input_files]\n    for input_file, output_file in zip(input_files, output_files):\n        print(f\"Predicting {input_file.name}\")\n        start = time()\n        predictor = FlarePredictor(tile_step_size=0.5, use_mirroring=False, device=torch.device(\"cpu\"))\n        predictor.initialize_from_trained_model_folder(model_folder, (\"all\",), save_model=save_model)\n        rw = predictor.plans_manager.image_reader_writer_class()\n        image, props = rw.read_images([input_file,])\n        _ = 
predictor.predict_single_npy_array(image, props, None, output_file, False)\n        print(f\"Prediction time: {time() - start:.2f}s\")\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"-i\", \"--input\", default=\"/workspace/inputs\")\n    parser.add_argument(\"-o\", \"--output\", default=\"/workspace/outputs\")\n    parser.add_argument(\"-m\", \"--model\", default=\"/opt/app/_trained_model\")\n    parser.add_argument(\"-save_model\", action=\"store_true\")\n    args = parser.parse_args()\n    predict_flare(args.input, args.output, args.model, args.save_model)"
  },
  {
    "path": "documentation/competitions/FLARE24/Task_2/readme.md",
    "content": "Authors: \\\nYannick Kirchhoff*, Ashis Ravindran*, Maximilian Rouven Rokuss, Benjamin Hamm, Constantin Ulrich, Klaus Maier-Hein<sup>&#8224;</sup>, Fabian Isensee<sup>&#8224;</sup>\n\n*: equal contribution \\\n&#8224;: equal contribution\n\n# Introduction\n\nThis document describes our contribution to [Task 2 of the FLARE24 Challenge](https://www.codabench.org/competitions/2320/).\nOur model is basically a default nnU-Net with a custom low-resolution setting and OpenVINO optimizations for faster CPU inference.\n\n# Experiment Planning and Preprocessing\n\nBring the downloaded data into the [nnU-Net format](../../../nnUNet/documentation/dataset_format.md) and add the dataset.json file as given here:\n\n```json\n{\n    \"name\": \"Dataset311_FLARE24Task2_labeled\",\n    \"description\": \"Abdominal Organ Segmentation\",\n    \"labels\": {\n        \"background\": 0,\n        \"liver\": 1,\n        \"right kidney\": 2,\n        \"spleen\": 3,\n        \"pancreas\": 4,\n        \"aorta\": 5,\n        \"ivc\": 6,\n        \"rag\": 7,\n        \"lag\": 8,\n        \"gallbladder\": 9,\n        \"esophagus\": 10,\n        \"stomach\": 11,\n        \"duodenum\": 12,\n        \"left kidney\": 13\n    },\n    \"file_ending\": \".nii.gz\",\n    \"channel_names\": {\n        \"0\": \"CT\"\n    },\n    \"overwrite_image_reader_writer\": \"NibabelIOWithReorient\",\n    \"numTraining\": 50\n}\n```\n\nAfterwards you can run the default nnU-Net planning and preprocessing:\n\n```bash\nnnUNetv2_plan_and_preprocess -d 311 -c 3d_fullres\n```\n\n## Edit the plans files\n\nThe generated `nnUNetPlans.json` file needs to be edited to incorporate the custom low-resolution setting.\n\n```json\n        \"3d_halfres\": {\n            \"inherits_from\": \"3d_fullres\",\n            \"data_identifier\": \"nnUNetPlans_3d_halfres\",\n            \"spacing\": [\n                5,\n                1.6,\n                1.6\n            ]\n        },\n        \"3d_halfiso\": {\n      
      \"inherits_from\": \"3d_fullres\",\n            \"data_identifier\": \"nnUNetPlans_3d_halfiso\",\n            \"spacing\": [\n                2.5,\n                2.5,\n                2.5\n            ]\n        },\n```\n\n`3d_halfres` is a configuration with exactly half resolution, used as an ablation of our submission; `3d_halfiso` is the isotropic configuration we submitted as our final solution.\n\n# Model training\n\nRun one of the following commands to train the respective configuration. `3d_halfiso` yielded significantly better results in our experiments as well as on the final test set and is the recommended configuration.\n\n```bash\nnnUNetv2_train 311 3d_halfres all\n\nnnUNetv2_train 311 3d_halfiso all\n```\n\n# Inference\n\nOur inference is optimized for efficient single scan prediction. For best performance, we strongly recommend running inference using the default `nnUNetv2_predict` command!\n\nInference using the provided script requires OpenVINO, which can easily be installed via\n\n```bash\npip install openvino\n```\n\nTo run inference, run the following command. `model_folder` is the folder containing the training results, for example `nnUNetTrainer__nnUNetPlans__3d_halfiso`. `-save_model` needs to be set once to precompile the model using OpenVINO. If no precompiled model exists, the inference script will fail!\n\n```bash\npython inference_flare_task1.py -i input_folder -o output_folder -m model_folder [-save_model]\n```"
  },
  {
    "path": "documentation/competitions/FLARE24/__init__.py",
    "content": ""
  },
  {
    "path": "documentation/competitions/Toothfairy2/__init__.py",
    "content": ""
  },
  {
    "path": "documentation/competitions/Toothfairy2/inference_script_semseg_only_customInf2.py",
    "content": "import argparse\nimport gc\nimport os\nfrom pathlib import Path\nfrom queue import Queue\nfrom threading import Thread\nfrom typing import Union, Tuple\n\nimport nnunetv2\nimport numpy as np\nimport torch\nfrom acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice\nfrom acvl_utils.cropping_and_padding.padding import pad_nd_image\nfrom batchgenerators.utilities.file_and_folder_operations import load_json, join\nfrom nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO\nfrom nnunetv2.inference.predict_from_raw_data import nnUNetPredictor\nfrom nnunetv2.inference.sliding_window_prediction import compute_gaussian\nfrom nnunetv2.utilities.find_class_by_name import recursive_find_python_class\nfrom nnunetv2.utilities.helpers import empty_cache, dummy_context\nfrom nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels\nfrom nnunetv2.utilities.plans_handling.plans_handler import PlansManager\nfrom torch._dynamo import OptimizedModule\nfrom torch.backends import cudnn\nfrom tqdm import tqdm\n\n\nclass CustomPredictor(nnUNetPredictor):\n    def predict_single_npy_array(self, input_image: np.ndarray, image_properties: dict,\n                                 segmentation_previous_stage: np.ndarray = None):\n        torch.set_num_threads(7)\n        with torch.no_grad():\n            self.network = self.network.to(self.device)\n            self.network.eval()\n\n            if self.verbose:\n                print('preprocessing')\n            preprocessor = self.configuration_manager.preprocessor_class(verbose=self.verbose)\n            data, _, image_properties = preprocessor.run_case_npy(input_image, None, image_properties,\n                                                self.plans_manager,\n                                                self.configuration_manager,\n                                                self.dataset_json)\n\n            data = torch.from_numpy(data)\n            del 
input_image\n            if self.verbose:\n                print('predicting')\n\n            predicted_logits = self.predict_preprocessed_image(data)\n\n            if self.verbose: print('Prediction done')\n\n            segmentation = self.convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits,\n                                                                                            image_properties)\n        return segmentation\n\n    def initialize_from_trained_model_folder(self, model_training_output_dir: str,\n                                             use_folds: Union[Tuple[Union[int, str]], None],\n                                             checkpoint_name: str = 'checkpoint_final.pth'):\n        \"\"\"\n        This is used when making predictions with a trained model\n        \"\"\"\n        if use_folds is None:\n            use_folds = nnUNetPredictor.auto_detect_available_folds(model_training_output_dir, checkpoint_name)\n\n        dataset_json = load_json(join(model_training_output_dir, 'dataset.json'))\n        plans = load_json(join(model_training_output_dir, 'plans.json'))\n        plans_manager = PlansManager(plans)\n\n        if isinstance(use_folds, str):\n            use_folds = [use_folds]\n\n        parameters = []\n        for i, f in enumerate(use_folds):\n            f = int(f) if f != 'all' else f\n            checkpoint = torch.load(join(model_training_output_dir, f'fold_{f}', checkpoint_name),\n                                    map_location=torch.device('cpu'), weights_only=False)\n            if i == 0:\n                trainer_name = checkpoint['trainer_name']\n                configuration_name = checkpoint['init_args']['configuration']\n                inference_allowed_mirroring_axes = checkpoint['inference_allowed_mirroring_axes'] if \\\n                    'inference_allowed_mirroring_axes' in checkpoint.keys() else None\n\n            parameters.append(join(model_training_output_dir, f'fold_{f}', 
checkpoint_name))\n\n        configuration_manager = plans_manager.get_configuration(configuration_name)\n        # restore network\n        num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)\n        trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], \"training\", \"nnUNetTrainer\"),\n                                                    trainer_name, 'nnunetv2.training.nnUNetTrainer')\n        if trainer_class is None:\n            raise RuntimeError(f'Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. '\n                               f'Please place it there (in any .py file)!')\n        network = trainer_class.build_network_architecture(\n            configuration_manager.network_arch_class_name,\n            configuration_manager.network_arch_init_kwargs,\n            configuration_manager.network_arch_init_kwargs_req_import,\n            num_input_channels,\n            plans_manager.get_label_manager(dataset_json).num_segmentation_heads,\n            enable_deep_supervision=False\n        )\n\n        self.plans_manager = plans_manager\n        self.configuration_manager = configuration_manager\n        self.list_of_parameters = parameters\n        self.network = network\n        self.dataset_json = dataset_json\n        self.trainer_name = trainer_name\n        self.allowed_mirroring_axes = inference_allowed_mirroring_axes\n        self.label_manager = plans_manager.get_label_manager(dataset_json)\n        if ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) \\\n                and not isinstance(self.network, OptimizedModule):\n            print('Using torch.compile')\n            self.network = torch.compile(self.network)\n\n    @torch.inference_mode(mode=True)\n    def predict_preprocessed_image(self, image):\n        empty_cache(self.device)\n        data_device = torch.device('cpu')\n        
predicted_logits_device = torch.device('cpu')\n        gaussian_device = torch.device('cpu')\n        compute_device = torch.device('cuda:0')\n\n        data, slicer_revert_padding = pad_nd_image(image, self.configuration_manager.patch_size,\n                                                   'constant', {'value': 0}, True,\n                                                   None)\n        del image\n\n        slicers = self._internal_get_sliding_window_slicers(data.shape[1:])\n\n        empty_cache(self.device)\n\n        data = data.to(data_device)\n        predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]),\n                                       dtype=torch.half,\n                                       device=predicted_logits_device)\n        gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,\n                                    value_scaling_factor=10,\n                                    device=gaussian_device, dtype=torch.float16)\n\n        if not self.allow_tqdm and self.verbose:\n            print(f'running prediction: {len(slicers)} steps')\n\n        for p in self.list_of_parameters:\n            # network weights have to be updated outside autocast!\n            # we are loading parameters on demand instead of loading them upfront. 
This reduces memory footprint a lot.\n            # each set of parameters is only used once on the test set (one image) so run time wise this is almost the\n            # same\n            # weights_only is an argument of torch.load, not of load_state_dict\n            self.network.load_state_dict(torch.load(p, map_location=compute_device, weights_only=False)['network_weights'])\n            with torch.autocast(self.device.type, enabled=True):\n                for sl in tqdm(slicers, disable=not self.allow_tqdm):\n                    pred = self._internal_maybe_mirror_and_predict(data[sl][None].to(compute_device))[0].to(\n                        predicted_logits_device)\n                    pred /= (pred.max() / 100)\n                    predicted_logits[sl] += (pred * gaussian)\n                del pred\n        empty_cache(self.device)\n        return predicted_logits\n\n    def convert_predicted_logits_to_segmentation_with_correct_shape(self, predicted_logits, props):\n        old = torch.get_num_threads()\n        torch.set_num_threads(7)\n\n        # resample to original shape\n        spacing_transposed = [props['spacing'][i] for i in self.plans_manager.transpose_forward]\n        current_spacing = self.configuration_manager.spacing if \\\n            len(self.configuration_manager.spacing) == \\\n            len(props['shape_after_cropping_and_before_resampling']) else \\\n            [spacing_transposed[0], *self.configuration_manager.spacing]\n        predicted_logits = self.configuration_manager.resampling_fn_probabilities(predicted_logits,\n                                                                                  props[\n                                                                                      'shape_after_cropping_and_before_resampling'],\n                                                                                  current_spacing,\n                                                                                  [props['spacing'][i] for i in\n                                                                 
                  self.plans_manager.transpose_forward])\n\n        segmentation = None\n        pp = None\n        try:\n            with torch.no_grad():\n                pp = predicted_logits.to('cuda:0')\n                segmentation = pp.argmax(0).cpu()\n                del pp\n        except RuntimeError:\n            del segmentation, pp\n            torch.cuda.empty_cache()\n            segmentation = predicted_logits.argmax(0)\n        del predicted_logits\n\n        # segmentation may be torch.Tensor but we continue with numpy\n        if isinstance(segmentation, torch.Tensor):\n            segmentation = segmentation.cpu().numpy()\n\n        # put segmentation in bbox (revert cropping)\n        segmentation_reverted_cropping = np.zeros(props['shape_before_cropping'],\n                                                  dtype=np.uint8 if len(\n                                                      self.label_manager.foreground_labels) < 255 else np.uint16)\n        slicer = bounding_box_to_slice(props['bbox_used_for_cropping'])\n        segmentation_reverted_cropping[slicer] = segmentation\n        del segmentation\n\n        # revert transpose\n        segmentation_reverted_cropping = segmentation_reverted_cropping.transpose(self.plans_manager.transpose_backward)\n        torch.set_num_threads(old)\n        return segmentation_reverted_cropping\n\n\ndef predict_semseg(im, prop, semseg_trained_model, semseg_folds):\n    # initialize predictors\n    pred_semseg = CustomPredictor(\n        tile_step_size=0.5,\n        use_mirroring=True,\n        use_gaussian=True,\n        perform_everything_on_device=False,\n        allow_tqdm=True\n    )\n    pred_semseg.initialize_from_trained_model_folder(\n        semseg_trained_model,\n        use_folds=semseg_folds,\n        checkpoint_name='checkpoint_final.pth'\n    )\n\n    semseg_pred = pred_semseg.predict_single_npy_array(\n        im, prop, None\n    )\n    torch.cuda.empty_cache()\n    gc.collect()\n    return 
semseg_pred\n\n\ndef map_labels_to_toothfairy(predicted_seg: np.ndarray) -> np.ndarray:\n    # Create an array that maps the labels directly\n    max_label = 42\n    mapping = np.arange(max_label + 1)\n\n    # Define the specific remapping\n    remapping = {19: 21, 20: 22, 21: 23, 22: 24, 23: 25, 24: 26, 25: 27, 26: 28,\n                 27: 31, 28: 32, 29: 33, 30: 34, 31: 35, 32: 36, 33: 37, 34: 38,\n                 35: 41, 36: 42, 37: 43, 38: 44, 39: 45, 40: 46, 41: 47, 42: 48}\n\n    # Apply the remapping\n    for k, v in remapping.items():\n        mapping[k] = v\n\n    return mapping[predicted_seg]\n\n\ndef postprocess(prediction_npy, vol_per_voxel, verbose: bool = False):\n    cutoffs = {1: 0.0,\n               2: 78411.5,\n               3: 0.0,\n               4: 0.0,\n               5: 2800.0,\n               6: 1216.5,\n               7: 0.0,\n               8: 6222.0,\n               9: 1573.0,\n               10: 946.0,\n               11: 0.0,\n               12: 6783.5,\n               13: 9469.5,\n               14: 0.0,\n               15: 2260.0,\n               16: 3566.0,\n               17: 6321.0,\n               18: 4221.5,\n               19: 5829.0,\n               20: 0.0,\n               21: 0.0,\n               22: 468.0,\n               23: 1555.0,\n               24: 1291.5,\n               25: 2834.5,\n               26: 584.5,\n               27: 0.0,\n               28: 0.0,\n               29: 0.0,\n               30: 0.0,\n               31: 1935.5,\n               32: 0.0,\n               33: 0.0,\n               34: 6140.0,\n               35: 0.0,\n               36: 0.0,\n               37: 0.0,\n               38: 2710.0,\n               39: 0.0,\n               40: 0.0,\n               41: 0.0,\n               42: 970.0}\n\n    vol_per_voxel_cutoffs = 0.3 * 0.3 * 0.3\n    for c in cutoffs.keys():\n        co = cutoffs[c]\n        if co > 0:\n            mask = prediction_npy == c\n            pred_vol = np.sum(mask) * 
vol_per_voxel\n            if 0 < pred_vol < (co * vol_per_voxel_cutoffs):\n                prediction_npy[mask] = 0\n                if verbose:\n                    print(\n                        f'removed label {c} because predicted volume of {pred_vol} is less than the cutoff {co * vol_per_voxel_cutoffs}')\n    return prediction_npy\n\n\nif __name__ == '__main__':\n    os.environ['nnUNet_compile'] = 'f'\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('-i', '--input_folder', type=Path, default=\"/input/images/cbct/\")\n    parser.add_argument('-o', '--output_folder', type=Path, default=\"/output/images/oral-pharyngeal-segmentation/\")\n    parser.add_argument('-sem_mod', '--semseg_trained_model', type=str,\n                        default=\"/opt/app/_trained_model/semseg_trained_model\")\n    parser.add_argument('--semseg_folds', type=str, nargs='+', default=[0, 1])\n    args = parser.parse_args()\n\n    args.output_folder.mkdir(exist_ok=True, parents=True)\n\n    semseg_folds = [i if i == 'all' else int(i) for i in args.semseg_folds]\n    semseg_trained_model = args.semseg_trained_model\n\n    rw = SimpleITKIO()\n\n    input_files = list(args.input_folder.glob('*.nii.gz')) + list(args.input_folder.glob('*.mha'))\n\n    for input_fname in input_files:\n        output_fname = args.output_folder / input_fname.name\n\n        # we start with the instance seg because we can then start converting that while semseg is being predicted\n        # load test image\n        im, prop = rw.read_images([input_fname])\n\n        with torch.no_grad():\n            semseg_pred = predict_semseg(im, prop, semseg_trained_model, semseg_folds)\n            torch.cuda.empty_cache()\n            gc.collect()\n\n        # now postprocess\n        semseg_pred = postprocess(semseg_pred, np.prod(prop['spacing']), True)\n\n        semseg_pred = map_labels_to_toothfairy(semseg_pred)\n\n        # now save\n        rw.write_seg(semseg_pred, output_fname, prop)\n"
  },
  {
    "path": "documentation/competitions/Toothfairy2/readme.md",
    "content": "Authors: \\\nFabian Isensee*, Yannick Kirchhoff*, Lars Kraemer, Max Rokuss, Constantin Ulrich, Klaus H. Maier-Hein \n\n*: equal contribution\n\nAuthor Affiliations:\\\nDivision of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg \\\nHelmholtz Imaging\n\n# Introduction\n\nThis document describes our submission to the [Toothfairy2 Challenge](https://toothfairy2.grand-challenge.org/toothfairy2/). \nOur model is essentially an nnU-Net ResEnc L with the patch size upscaled to 160x320x320 pixels. We disable left/right \nmirroring and train for 1500 instead of the standard 1000 epochs. Training was done on either 2xA100 40GB or one GH200 96GB.\n\n# Dataset Conversion\n\nAdapt and run the [dataset conversion script](../../../nnunetv2/dataset_conversion/Dataset119_ToothFairy2_All.py). \nThis script just converts the MHA files to NIfTI (smaller file size) and removes the unused label ids.\n\n# Experiment Planning and Preprocessing\n\n## Extract fingerprint:\n`nnUNetv2_extract_fingerprint -d 119 -np 48`\n\n## Run planning:\n`nnUNetv2_plan_experiment -d 119 -pl nnUNetPlannerResEncL_torchres`\n\nThis planner not only uses the ResEncL configuration but also replaces the default resampling scheme with one that is \nfaster (but less precise). Since all images in the challenge (train and test) should already have 0.3x0.3x0.3 spacing, \nresampling is not required. This is just here as a safety measure. 
The speed is needed at inference time because grand \nchallenge imposes a limit of 10 minutes per case.\n\n## Edit the plans files\nAdd the following configuration to the generated plans file:\n\n```json\n        \"3d_fullres_torchres_ps160x320x320_bs2\": {\n            \"inherits_from\": \"3d_fullres\",\n            \"data_identifier\": \"nnUNetPlans_3d_fullres_torchres_ctnorm\",\n            \"patch_size\": [\n                160,\n                320,\n                320\n            ],\n            \"normalization_schemes\": [\n                \"CTNormalization\"\n            ],\n            \"architecture\": {\n                \"network_class_name\": \"dynamic_network_architectures.architectures.unet.ResidualEncoderUNet\",\n                \"arch_kwargs\": {\n                    \"n_stages\": 7,\n                    \"features_per_stage\": [\n                        32,\n                        64,\n                        128,\n                        256,\n                        320,\n                        320,\n                        320\n                    ],\n                    \"conv_op\": \"torch.nn.modules.conv.Conv3d\",\n                    \"kernel_sizes\": [\n                        [\n                            3,\n                            3,\n                            3\n                        ],\n                        [\n                            3,\n                            3,\n                            3\n                        ],\n                        [\n                            3,\n                            3,\n                            3\n                        ],\n                        [\n                            3,\n                            3,\n                            3\n                        ],\n                        [\n                            3,\n                            3,\n                            3\n                        ],\n                        [\n                      
      3,\n                            3,\n                            3\n                        ],\n                        [\n                            3,\n                            3,\n                            3\n                        ]\n                    ],\n                    \"strides\": [\n                        [\n                            1,\n                            1,\n                            1\n                        ],\n                        [\n                            2,\n                            2,\n                            2\n                        ],\n                        [\n                            2,\n                            2,\n                            2\n                        ],\n                        [\n                            2,\n                            2,\n                            2\n                        ],\n                        [\n                            2,\n                            2,\n                            2\n                        ],\n                        [\n                            2,\n                            2,\n                            2\n                        ],\n                        [\n                            1,\n                            2,\n                            2\n                        ]\n                    ],\n                    \"n_blocks_per_stage\": [\n                        1,\n                        3,\n                        4,\n                        6,\n                        6,\n                        6,\n                        6\n                    ],\n                    \"n_conv_per_stage_decoder\": [\n                        1,\n                        1,\n                        1,\n                        1,\n                        1,\n                        1\n                    ],\n                    \"conv_bias\": true,\n                    \"norm_op\": 
\"torch.nn.modules.instancenorm.InstanceNorm3d\",\n                    \"norm_op_kwargs\": {\n                        \"eps\": 1e-05,\n                        \"affine\": true\n                    },\n                    \"dropout_op\": null,\n                    \"dropout_op_kwargs\": null,\n                    \"nonlin\": \"torch.nn.LeakyReLU\",\n                    \"nonlin_kwargs\": {\n                        \"inplace\": true\n                    }\n                },\n                \"_kw_requires_import\": [\n                    \"conv_op\",\n                    \"norm_op\",\n                    \"dropout_op\",\n                    \"nonlin\"\n                ]\n            }            \n        }\n```\nAside from changing the patch size this makes the architecture one stage deeper (one more pooling + res blocks), enabling\nit to make effective use of the larger input\n\n# Preprocessing\n`nnUNetv2_preprocess -d 119 -c 3d_fullres_torchres_ps160x320x320_bs2 -plans_name nnUNetResEncUNetLPlans_torchres -np 48`\n\n# Training\nWe train two models on all training cases:\n\n```bash\nnnUNetv2_train 119 3d_fullres_torchres_ps160x320x320_bs2 all -p nnUNetResEncUNetLPlans_torchres -tr nnUNetTrainer_onlyMirror01_1500ep\nnnUNet_results=${nnUNet_results}_2 nnUNetv2_train 119 3d_fullres_torchres_ps160x320x320_bs2 all -p nnUNetResEncUNetLPlans_torchres -tr nnUNetTrainer_onlyMirror01_1500ep\n```\nModels are trained from scratch.\n\nNote how in the second line we overwrite the nnUNet_results variable in order to be able to train the same model twice without overwriting the results\n\nWe recommend to increase the number of processes used for data augmentation. Otherwise you can run into CPU bottlenecks.\nUse `export nnUNet_n_proc_DA=32` or higher (if your system permits!).\n\n# Inference\nWe ensemble the two models from above. On a technical level we copy the two fold_all folders into one training output \ndirectory and rename them to fold_0 and fold_1. 
This lets us use nnU-Net's cross-validation ensembling strategy which \nis more computationally efficient (needed for time limit on grand-challenge.org).\n\nRun inference with the [inference script](inference_script_semseg_only_customInf2.py)\n\n# Postprocessing\nIf the prediction of a class on some test case is smaller than the corresponding cutoff size then it is removed \n(replaced with background).\n\nCutoff values were optimized using a five-fold cross-validation on the Toothfairy2 training data. We optimize HD95 and Dice separately. \nThe final cutoff for each class is then the smaller value between the two metrics. You can find our volume cutoffs in the inference \nscript as part of our `postprocess` function.    "
  },
  {
    "path": "documentation/competitions/__init__.py",
    "content": ""
  },
  {
    "path": "documentation/convert_msd_dataset.md",
    "content": "Use `nnUNetv2_convert_MSD_dataset`.\n\nRead `nnUNetv2_convert_MSD_dataset -h` for usage instructions."
  },
  {
    "path": "documentation/dataset_format.md",
    "content": "# nnU-Net dataset format\nThe only way to bring your data into nnU-Net is by storing it in a specific format. Due to nnU-Net's roots in the\n[Medical Segmentation Decathlon](http://medicaldecathlon.com/) (MSD), its dataset is heavily inspired but has since \ndiverged (see also [here](#how-to-use-decathlon-datasets)) from the format used in the MSD.\n\nDatasets consist of three components: raw images, corresponding segmentation maps and a dataset.json file specifying \nsome metadata. \n\nIf you are migrating from nnU-Net v1, read [this](#how-to-use-nnu-net-v1-tasks) to convert your existing Tasks.\n\n\n## What do training cases look like?\nEach training case is associated with an identifier = a unique name for that case. This identifier is used by nnU-Net to \nconnect images with the correct segmentation.\n\nA training case consists of images and their corresponding segmentation. \n\n**Images** is plural because nnU-Net supports arbitrarily many input channels. In order to be as flexible as possible, \nnnU-net requires each input channel to be stored in a separate image (with the sole exception being RGB natural \nimages). So these images could for example be a T1 and a T2 MRI (or whatever else you want). The different input \nchannels MUST have the same geometry (same shape, spacing (if applicable) etc.) and\nmust be co-registered (if applicable). Input channels are identified by nnU-Net by their FILE_ENDING: a four-digit integer at the end \nof the filename. Image files must therefore follow the following naming convention: {CASE_IDENTIFIER}_{XXXX}.{FILE_ENDING}. \nHereby, XXXX is the 4-digit modality/channel identifier (should be unique for each modality/channel, e.g., “0000” for T1, “0001” for \nT2 MRI, …) and FILE_ENDING is the file extension used by your image format (.png, .nii.gz, ...). 
See below for concrete examples.\nThe dataset.json file connects channel names with the channel identifiers in the 'channel_names' key (see below for details).\n\nSide note: Typically, each channel/modality needs to be stored in a separate file and is accessed with the XXXX channel identifier. \nThe exception is natural images (RGB; .png), where the three color channels can all be stored in one file (see the \n[road segmentation](../nnunetv2/dataset_conversion/Dataset120_RoadSegmentation.py) dataset as an example). \n\n**Segmentations** must share the same geometry with their corresponding images (same shape etc.). Segmentations are \ninteger maps with each value representing a semantic class. The background must be 0. If there is no background, then \ndo not use the label 0 for something else! Integer values of your semantic classes must be consecutive (0, 1, 2, 3, \n...). Of course, not all labels have to be present in each training case. Segmentations are saved as {CASE_IDENTIFIER}.{FILE_ENDING}.\n\nWithin a training case, all image geometries (input channels, corresponding segmentation) must match. Between training \ncases, they can of course differ. nnU-Net takes care of that.\n\nImportant: The input channels must be consistent! Concretely, **all images need the same input channels in the same \norder and all input channels have to be present every time**. This is also true for inference!\n\n\n## Supported file formats\nnnU-Net expects the same file format for images and segmentations! These will also be used for inference. For now, it \nis thus not possible to train on .png and then run inference on .jpg.\n\nOne big change in nnU-Net V2 is the support of multiple input file types. Gone are the days of converting everything to .nii.gz!\nThis is implemented by abstracting the input and output of images + segmentations through `BaseReaderWriter`. nnU-Net \ncomes with a broad collection of Readers+Writers and you can even add your own to support your data format! 
\nSee [here](../nnunetv2/imageio/readme.md).\n\nAs a nice bonus, nnU-Net now also natively supports 2D input images and you no longer have to mess around with \nconversions to pseudo 3D niftis. Yuck. That was disgusting.\n\nNote that internally (for storing and accessing preprocessed images) nnU-Net will use its own file format, irrespective \nof what the raw data was provided in! This is for performance reasons.\n\n\nBy default, the following file formats are supported:\n\n- NaturalImage2DIO: .png, .bmp, .tif\n- NibabelIO: .nii.gz, .nrrd, .mha\n- NibabelIOWithReorient: .nii.gz, .nrrd, .mha. This reader will reorient images to RAS!\n- SimpleITKIO: .nii.gz, .nrrd, .mha\n- Tiff3DIO: .tif, .tiff. 3D tif images! Since TIF does not have a standardized way of storing spacing information, \nnnU-Net expects each TIF file to be accompanied by an identically named .json file that contains this information (see\n[here](#datasetjson)).\n\nThe file extension lists are not exhaustive and depend on what the backend supports. For example, nibabel and SimpleITK \nsupport more than the three given here. The file endings given here are just the ones we tested!\n\nIMPORTANT: nnU-Net can only be used with file formats that use lossless (or no) compression! Because the file \nformat is defined for an entire dataset (and not separately for images and segmentations, this could be a todo for \nthe future), we must ensure that there are no compression artifacts that destroy the segmentation maps. So no .jpg and \nthe likes! \n\n## Dataset folder structure\nDatasets must be located in the `nnUNet_raw` folder (which you either define when installing nnU-Net or export/set every \ntime you intend to run nnU-Net commands!).\nEach segmentation dataset is stored as a separate 'Dataset'. Datasets are associated with a dataset ID, a three digit \ninteger, and a dataset name (which you can freely choose): For example, Dataset005_Prostate has 'Prostate' as dataset name and \nthe dataset id is 5. 
Datasets are stored in the `nnUNet_raw` folder like this:\n\n    nnUNet_raw/\n    ├── Dataset001_BrainTumour\n    ├── Dataset002_Heart\n    ├── Dataset003_Liver\n    ├── Dataset004_Hippocampus\n    ├── Dataset005_Prostate\n    ├── ...\n\nWithin each dataset folder, the following structure is expected:\n\n    Dataset001_BrainTumour/\n    ├── dataset.json\n    ├── imagesTr\n    ├── imagesTs  # optional\n    └── labelsTr\n\n\nWhen adding your custom dataset, take a look at the [dataset_conversion](../nnunetv2/dataset_conversion) folder and \npick an id that is not already taken. IDs 001-010 are for the Medical Segmentation Decathlon.\n\n- **imagesTr** contains the images belonging to the training cases. nnU-Net will perform pipeline configuration, training with \ncross-validation, as well as finding postprocessing and the best ensemble using this data. \n- **imagesTs** (optional) contains the images that belong to the test cases. nnU-Net does not use them! This could just \nbe a convenient location for you to store these images. It is a remnant of the Medical Segmentation Decathlon folder structure.\n- **labelsTr** contains the images with the ground truth segmentation maps for the training cases. \n- **dataset.json** contains metadata of the dataset.\n\nThe scheme introduced [above](#what-do-training-cases-look-like) results in the following folder structure. Given \nis an example for the first Dataset of the MSD: BrainTumour. This dataset has four input channels: FLAIR (0000), \nT1w (0001), T1gd (0002) and T2w (0003). 
Note that the imagesTs folder is optional and does not have to be present.\n\n    nnUNet_raw/Dataset001_BrainTumour/\n    ├── dataset.json\n    ├── imagesTr\n    │   ├── BRATS_001_0000.nii.gz\n    │   ├── BRATS_001_0001.nii.gz\n    │   ├── BRATS_001_0002.nii.gz\n    │   ├── BRATS_001_0003.nii.gz\n    │   ├── BRATS_002_0000.nii.gz\n    │   ├── BRATS_002_0001.nii.gz\n    │   ├── BRATS_002_0002.nii.gz\n    │   ├── BRATS_002_0003.nii.gz\n    │   ├── ...\n    ├── imagesTs\n    │   ├── BRATS_485_0000.nii.gz\n    │   ├── BRATS_485_0001.nii.gz\n    │   ├── BRATS_485_0002.nii.gz\n    │   ├── BRATS_485_0003.nii.gz\n    │   ├── BRATS_486_0000.nii.gz\n    │   ├── BRATS_486_0001.nii.gz\n    │   ├── BRATS_486_0002.nii.gz\n    │   ├── BRATS_486_0003.nii.gz\n    │   ├── ...\n    └── labelsTr\n        ├── BRATS_001.nii.gz\n        ├── BRATS_002.nii.gz\n        ├── ...\n\nHere is another example of the second dataset of the MSD, which has only one input channel:\n\n    nnUNet_raw/Dataset002_Heart/\n    ├── dataset.json\n    ├── imagesTr\n    │   ├── la_003_0000.nii.gz\n    │   ├── la_004_0000.nii.gz\n    │   ├── ...\n    ├── imagesTs\n    │   ├── la_001_0000.nii.gz\n    │   ├── la_002_0000.nii.gz\n    │   ├── ...\n    └── labelsTr\n        ├── la_003.nii.gz\n        ├── la_004.nii.gz\n        ├── ...\n\nRemember: For each training case, all images must have the same geometry to ensure that their pixel arrays are aligned. Also \nmake sure that all your data is co-registered!\n\nSee also [dataset format inference](dataset_format_inference.md)!!\n\n## dataset.json\nThe dataset.json contains metadata that nnU-Net needs for training. 
We have greatly reduced the number of required \nfields since version 1!\n\nHere is what the dataset.json should look like, using Dataset005_Prostate from the MSD as an example (the `#` comments \nare annotations for this documentation only; JSON itself does not allow comments):\n\n    { \n     \"channel_names\": {  # formerly modalities\n       \"0\": \"T2\", \n       \"1\": \"ADC\"\n     }, \n     \"labels\": {  # THIS IS DIFFERENT NOW!\n       \"background\": 0,\n       \"PZ\": 1,\n       \"TZ\": 2\n     }, \n     \"numTraining\": 32, \n     \"file_ending\": \".nii.gz\",\n     \"overwrite_image_reader_writer\": \"SimpleITKIO\"  # optional! If not provided nnU-Net will automatically determine the ReaderWriter\n     }\n\nThe channel_names determine the normalization used by nnU-Net. If a channel is marked as 'CT', then a global \nnormalization based on the intensities in the foreground pixels will be used. If it is something else, per-channel \nz-scoring will be used. Refer to the methods section in [our paper](https://www.nature.com/articles/s41592-020-01008-z) \nfor more details. nnU-Net v2 introduces a few more normalization schemes to \nchoose from and allows you to define your own, see [here](explanation_normalization.md) for more information. \n\nImportant changes relative to nnU-Net v1:\n- \"modality\" is now called \"channel_names\" to remove strong bias to medical images\n- labels are structured differently (name -> int instead of int -> name). This was needed to support [region-based training](region_based_training.md)\n- \"file_ending\" is added to support different input file types\n- \"overwrite_image_reader_writer\" is optional! It can be used to specify a certain (custom) ReaderWriter class that should \nbe used with this dataset. If not provided, nnU-Net will automatically determine the ReaderWriter\n- \"regions_class_order\" is only used in [region-based training](region_based_training.md)\n\nThere is a utility with which you can generate the dataset.json automatically. You can find it \n[here](../nnunetv2/dataset_conversion/generate_dataset_json.py). 
\nSee our examples in [dataset_conversion](../nnunetv2/dataset_conversion) for how to use it. And read its documentation!\n\nAs described above, a json file that contains spacing information is required for TIFF files.\nAn example for a 3D TIFF stack with units corresponding to 7.6 in x and y, 80 in z is:\n\n```\n{\n    \"spacing\": [7.6, 7.6, 80.0]\n}\n```\n\nWithin the dataset folder, this file (named `cell6.json` in this example) would be placed in the following folders:\n\n    nnUNet_raw/Dataset123_Foo/\n    ├── dataset.json\n    ├── imagesTr\n    │   ├── cell6.json\n    │   └── cell6_0000.tif\n    └── labelsTr\n        ├── cell6.json\n        └── cell6.tif\n\n\n## How to use nnU-Net v1 Tasks\nIf you are migrating from the old nnU-Net, convert your existing datasets with `nnUNetv2_convert_old_nnUNet_dataset`!\n\nExample for migrating a nnU-Net v1 Task:\n```bash\nnnUNetv2_convert_old_nnUNet_dataset /media/isensee/raw_data/nnUNet_raw_data_base/nnUNet_raw_data/Task027_ACDC Dataset027_ACDC \n```\nUse `nnUNetv2_convert_old_nnUNet_dataset -h` for detailed usage instructions.\n\n\n## How to use decathlon datasets\nSee [convert_msd_dataset.md](convert_msd_dataset.md)\n\n## How to use 2D data with nnU-Net\n2D is now natively supported (yay!). See [here](#supported-file-formats) as well as the example dataset in this \n[script](../nnunetv2/dataset_conversion/Dataset120_RoadSegmentation.py).\n\n\n## How to update an existing dataset\nWhen updating a dataset it is best practice to remove the preprocessed data in `nnUNet_preprocessed/DatasetXXX_NAME` \nto ensure a fresh start. Then replace the data in `nnUNet_raw` and rerun `nnUNetv2_plan_and_preprocess`. Optionally, \nalso remove the results from old trainings.\n\n# Example dataset conversion scripts\nIn the `dataset_conversion` folder (see [here](../nnunetv2/dataset_conversion)) are multiple example scripts for \nconverting datasets into nnU-Net format. 
These scripts cannot be run as they are (you need to open them and change \nsome paths) but they are excellent examples for you to learn how to convert your own datasets into nnU-Net format. \nJust pick the dataset that is closest to yours as a starting point.\nThe list of dataset conversion scripts is continually updated. If you find that some publicly available dataset is \nmissing, feel free to open a PR to add it!\n"
  },
  {
    "path": "documentation/dataset_format_inference.md",
    "content": "# Data format for Inference \nRead the documentation on the overall [data format](dataset_format.md) first!\n\nThe data format for inference must match the one used for the raw data (**specifically, the images must be in exactly \nthe same format as in the imagesTr folder**). As before, the filenames must start with a\nunique identifier, followed by a 4-digit modality identifier. Here is an example for two different datasets:\n\n1) Task005_Prostate:\n\n    This task has 2 modalities, so the files in the input folder must look like this:\n\n        input_folder\n        ├── prostate_03_0000.nii.gz\n        ├── prostate_03_0001.nii.gz\n        ├── prostate_05_0000.nii.gz\n        ├── prostate_05_0001.nii.gz\n        ├── prostate_08_0000.nii.gz\n        ├── prostate_08_0001.nii.gz\n        ├── ...\n\n    _0000 has to be the T2 image and _0001 has to be the ADC image (as specified by 'channel_names' in the \ndataset.json), exactly the same as was used for training.\n\n2) Task002_Heart:\n\n        imagesTs\n        ├── la_001_0000.nii.gz\n        ├── la_002_0000.nii.gz\n        ├── la_006_0000.nii.gz\n        ├── ...\n    \n    Task002 only has one modality, so each case only has one _0000.nii.gz file.\n  \n\nThe segmentations in the output folder will be named {CASE_IDENTIFIER}.nii.gz (omitting the modality identifier).\n\nRemember that the file format used for inference (.nii.gz in this example) must be the same as was used for training \n(and as was specified in 'file_ending' in the dataset.json)!\n   "
  },
  {
    "path": "documentation/explanation_logging.md",
    "content": "# Logging in nnU-Net v2\n\n## Introduction\n\nLogging in nnU-Net is intentionally simple and centralized in\n`nnunetv2/training/logging/nnunet_logger.py`.\n\nThe trainer talks to one object, `MetaLogger`, and `MetaLogger` fans out logs to:\n\n- `LocalLogger` (always enabled): the source of truth for training curves, checkpoint logging state, and `progress.png`\n- optional external loggers (currently `WandbLogger`)\n\nThis keeps training code clean while still allowing external tracking backends.\n\n## Default behaviour\n\nWithout any setup, nnU-Net uses only `LocalLogger`.\n\nPer epoch, it stores:\n\n- `mean_fg_dice` and `ema_fg_dice` (EMA is computed automatically)\n- `dice_per_class_or_region`\n- `train_losses`, `val_losses`\n- `lrs`\n- `epoch_start_timestamps`, `epoch_end_timestamps`\n\nFrom these values, `progress.png` is updated in the fold output folder.\nOn checkpoint save/load, the local logging state is also saved/restored, so curves continue correctly after resume.\n\n## How to enable W&B\n\n1. Install W&B:\n\n```bash\npip install wandb\n```\n\n2. Enable the backend via environment variables:\n\n```bash\nexport nnUNet_wandb_enabled=1\nexport nnUNet_wandb_project=nnunet\nexport nnUNet_wandb_mode=online   # or offline\n```\n\n3. Run training normally:\n\n```bash\nnnUNetv2_train DATASET_NAME_OR_ID 3d_fullres 0\n```\n\nNotes:\n\n- `nnUNet_wandb_enabled` accepts `0/1` and `false/true` (case-insensitive). 
Other values raise an error.\n- When resuming (`--c`), W&B resume metadata in `fold_x/wandb/latest-run` is reused and duplicate older steps are skipped.\n\n## How to integrate a custom logger\n\nAdd a new logger class with the same minimal interface used by `MetaLogger`:\n\n- `update_config(self, config: dict)`\n- `log(self, key, value, step: int)`\n- `log_summary(self, key, value)`\n\nExample skeleton:\n\n```python\nclass MyLogger:\n    def __init__(self, output_folder, resume):\n        self.output_folder = output_folder\n        self.resume = resume\n\n    def update_config(self, config: dict):\n        ...\n\n    def log(self, key, value, step: int):\n        ...\n\n    def log_summary(self, key, value):\n        ...\n```\n\nThen register it in `MetaLogger.__init__` (for example behind an env var switch), similar to how `WandbLogger` is added.\n\nImportant integration detail:\n\n- `MetaLogger.log(...)` always writes to `LocalLogger` first.\n- If you introduce a brand-new per-epoch key, also add that key to `LocalLogger.my_fantastic_logging`, otherwise the local assertion will fail.\n"
  },
  {
    "path": "documentation/explanation_normalization.md",
    "content": "# Intensity normalization in nnU-Net \n\nThe type of intensity normalization applied in nnU-Net can be controlled via the `channel_names` (former `modalities`)\nentry in the dataset.json. Just like the old nnU-Net, per-channel z-scoring as well as dataset-wide z-scoring based on \nforeground intensities are supported. However, there have been a few additions as well.\n\nReminder: The `channel_names` entry typically looks like this: \n\n    \"channel_names\": {\n        \"0\": \"T2\",\n        \"1\": \"ADC\"\n    },\n\nIt has as many entries as there are input channels for the given dataset.\n\nTo tell you a secret, nnU-Net does not really care what your channels are called. We just use this to determine what normalization\nscheme will be used for the given dataset. nnU-Net requires you to specify a normalization strategy for each of your input channels! \nIf you enter a channel name that is not in the following list, the default (`zscore`) will be used.\n\nHere is a list of currently available normalization schemes:\n\n- `CT`: Perform CT normalization. Specifically, collect intensity values from the foreground classes (all but the \nbackground and ignore) from all training cases, compute the mean, standard deviation as well as the 0.5 and \n99.5 percentile of the values. Then clip to the percentiles, followed by subtraction of the mean and division with the \nstandard deviation. The normalization that is applied is the same for each training case (for this input channel).\nThe values used by nnU-Net for normalization are stored in the `foreground_intensity_properties_per_channel` entry in the \ncorresponding plans file. This normalization is suitable for modalities presenting physical quantities such as CT \nimages and ADC maps.\n- `noNorm` : do not perform any normalization at all\n- `rescale_to_0_1`: rescale the intensities to [0, 1]\n- `rgb_to_0_1`: assumes uint8 inputs. 
Divides by 255 to rescale uint8 to [0, 1]\n- `zscore`/anything else: perform z-scoring (subtract the mean and divide by the standard deviation) separately for each training case\n\n**Important:** The nnU-Net default is to perform 'CT' normalization for CT images and 'zscore' for everything else! If \nyou deviate from that path, make sure to benchmark whether that actually improves results! \n\n# How to implement custom normalization strategies?\n- Head over to nnunetv2/preprocessing/normalization\n- implement a new image normalization class by deriving from ImageNormalization\n- register it in nnunetv2/preprocessing/normalization/map_channel_name_to_normalization.py:channel_name_to_normalization_mapping. \nThis is where you specify a channel name that should be associated with it\n- use it by specifying the correct channel_name\n\nNormalization can only be applied to one channel at a time. There is currently no way of implementing a normalization scheme \nthat gets multiple channels as input to be used jointly!"
  },
  {
    "path": "documentation/explanation_plans_files.md",
    "content": "# Modifying the nnU-Net Configurations\n\nnnU-Net provides unprecedented out-of-the-box segmentation performance for essentially any dataset we have evaluated \nit on. That said, there is always room for improvements. A fool-proof strategy for squeezing out the last bit of \nperformance is to start with the default nnU-Net, and then further tune it manually to a concrete dataset at hand.\n**This guide is about changes to the nnU-Net configuration you can make via the plans files. It does not cover code \nextensions of nnU-Net. For that, take a look [here](extending_nnunet.md)**\n\nIn nnU-Net V2, plans files are SO MUCH MORE powerful than they were in v1. There are a lot more knobs that you can \nturn without resorting to hacky solutions or even having to touch the nnU-Net code at all! And as an added bonus: \nplans files are now also .json files and no longer require users to fiddle with pickle. Just open them in your text \neditor of choice!\n\nIf overwhelmed, look at our [Examples](#examples)!\n\n# plans.json structure\n\nPlans have global and local settings. Global settings are applied to all configurations in that plans file while \nlocal settings are attached to a specific configuration.\n\n## Global settings\n\n- `foreground_intensity_properties_by_modality`: Intensity statistics of the foreground regions (all labels except \nbackground and ignore label), computed over all training cases. Used by [CT normalization scheme](explanation_normalization.md).\n- `image_reader_writer`: Name of the image reader/writer class that should be used with this dataset. You might want \nto change this if, for example, you would like to run inference with files that have a different file format. The \nclass that is named here must be located in nnunetv2.imageio!\n- `label_manager`: The name of the class that does label handling. Take a look at \nnnunetv2.utilities.label_handling.LabelManager to see what it does. 
If you decide to change it, place your version \nin nnunetv2.utilities.label_handling!\n- `transpose_forward`: nnU-Net transposes the input data so that the axes with the highest resolution (lowest spacing) \ncome last. This is because the 2D U-Net operates on the trailing dimensions (more efficient slicing due to internal \nmemory layout of arrays). Future work might move this setting to affect only individual configurations. \n`transpose_forward` is what numpy.transpose gets as new axis ordering.\n- `transpose_backward`: the axis ordering that inverts `transpose_forward`\n- \\[`original_median_shape_after_transp`\\]: just here for your information\n- \\[`original_median_spacing_after_transp`\\]: just here for your information\n- \\[`plans_name`\\]: do not change. Used internally\n- \\[`experiment_planner_used`\\]: just here as metadata so that we know what planner originally generated this file\n- \\[`dataset_name`\\]: do not change. This is the dataset these plans are intended for\n\n## Local settings\nPlans also have a `configurations` key in which the actual configurations are stored. `configurations` is again a \ndictionary, where the keys are the configuration names and the values are the local settings for each configuration.\n\nTo better understand the components describing the network topology in our plans files, please read section 6.2 \nin the [supplementary information](https://static-content.springer.com/esm/art%3A10.1038%2Fs41592-020-01008-z/MediaObjects/41592_2020_1008_MOESM1_ESM.pdf) \n(page 13) of our paper!\n\nLocal settings:\n- `spacing`: the target spacing used in this configuration\n- `patch_size`: the patch size used for training this configuration\n- `data_identifier`: the preprocessed data for this configuration will be saved in\n  nnUNet_preprocessed/DATASET_NAME/_data_identifier_. 
If you add a new configuration, remember to set a unique\n  data_identifier in order to not create conflicts with other configurations (unless you plan to reuse the data from\n  another configuration, for example as is done in the cascade)\n- `batch_size`: batch size used for training\n- `batch_dice`: whether to use batch dice (pretend all samples in the batch are one image, compute dice loss over that)\nor not (each sample in the batch is a separate image, compute dice loss for each sample and average over samples)\n- `preprocessor_name`: Name of the preprocessor class used for running preprocessing. Class must be located in \nnnunetv2.preprocessing.preprocessors\n- `use_mask_for_norm`: whether to use the nonzero mask for normalization or not (relevant for BraTS and the like, \nprobably False for all other datasets). Interacts with ImageNormalization class\n- `normalization_schemes`: mapping of channel identifier to ImageNormalization class name. ImageNormalization \nclasses must be located in nnunetv2.preprocessing.normalization. Also see [here](explanation_normalization.md)\n- `resampling_fn_data`: name of resampling function to be used for resizing image data. resampling function must be \ncallable(data, current_spacing, new_spacing, **kwargs). It must be located in nnunetv2.preprocessing.resampling\n- `resampling_fn_data_kwargs`: kwargs for resampling_fn_data\n- `resampling_fn_probabilities`: name of resampling function to be used for resizing predicted class probabilities/logits. \nresampling function must be `callable(data: Union[np.ndarray, torch.Tensor], current_spacing, new_spacing, **kwargs)`. It must be located in \nnnunetv2.preprocessing.resampling\n- `resampling_fn_probabilities_kwargs`: kwargs for resampling_fn_probabilities\n- `resampling_fn_seg`: name of resampling function to be used for resizing segmentation maps (integer: 0, 1, 2, 3, etc). \nresampling function must be callable(data, current_spacing, new_spacing, **kwargs). 
It must be located in \nnnunetv2.preprocessing.resampling\n- `resampling_fn_seg_kwargs`: kwargs for resampling_fn_seg\n- `network_arch_class_name`: UNet class name, can be used to integrate custom dynamic architectures\n- `UNet_base_num_features`: The number of starting features for the UNet architecture. Default is 32. Default: Features\nare doubled with each downsampling \n- `unet_max_num_features`: Maximum number of features (default: capped at 320 for 3D and 512 for 2d). The purpose is to \nprevent parameters from exploding too much. \n- `conv_kernel_sizes`: the convolutional kernel sizes used by nnU-Net in each stage of the encoder. The decoder \n  mirrors the encoder and is therefore not explicitly listed here! The list is as long as `n_conv_per_stage_encoder` has \n  entries\n- `n_conv_per_stage_encoder`: number of convolutions used per stage (=at a feature map resolution in the encoder) in the encoder. \n  Default is 2. The list has as many entries as the encoder has stages\n- `n_conv_per_stage_decoder`: number of convolutions used per stage in the decoder. Also see `n_conv_per_stage_encoder`\n- `num_pool_per_axis`: number of times each of the spatial axes is pooled in the network. Needed to know how to pad \n  image sizes during inference (num_pool = 5 means input must be divisible by 2**5=32)\n- `pool_op_kernel_sizes`: the pooling kernel sizes (and at the same time strides) for each stage of the encoder\n- \\[`median_image_size_in_voxels`\\]: the median size of the images of the training set at the current target spacing. \nDo not modify this as this is not used. It is just here for your information.\n\nSpecial local settings:\n- `inherits_from`: configurations can inherit from each other. This makes it easy to add new configurations that only\ndiffer in a few local settings from another. 
If using this, remember to set a new `data_identifier` (if needed)!\n- `previous_stage`: if this configuration is part of a cascade, we need to know what the previous stage (for example \nthe low resolution configuration) was. This needs to be specified here.\n- `next_stage`: if this configuration is part of a cascade, we need to know what possible subsequent stages are! This \nis because we need to export predictions in the correct spacing when running the validation. `next_stage` can either \nbe a string or a list of strings\n\n# Examples\n\n## Increasing the batch size for large datasets\nIf your dataset is large the training can benefit from larger batch_sizes. To do this, simply create a new \nconfiguration in the `configurations` dict\n\n    \"configurations\": {\n      \"3d_fullres_bs40\": {\n        \"inherits_from\": \"3d_fullres\",\n        \"batch_size\": 40\n      }\n    }\n\nNo need to change the data_identifier. `3d_fullres_bs40` will just use the preprocessed data from `3d_fullres`.\nNo need to rerun `nnUNetv2_preprocess` because we can use already existing data (if available) from `3d_fullres`.\n\n## Using custom preprocessors\nIf you would like to use a different preprocessor class then this can be specified as follows:\n\n    \"configurations\": {\n      \"3d_fullres_my_preprocesor\": {\n        \"inherits_from\": \"3d_fullres\",\n        \"preprocessor_name\": MY_PREPROCESSOR,\n        \"data_identifier\": \"3d_fullres_my_preprocesor\"\n      }\n    }\n\nYou need to run preprocessing for this new configuration: \n`nnUNetv2_preprocess -d DATASET_ID -c 3d_fullres_my_preprocesor` because it changes the preprocessing. 
Remember to \nset a unique `data_identifier` whenever you make modifications to the preprocessed data!\n\n## Change target spacing\n\n    \"configurations\": {\n      \"3d_fullres_my_spacing\": {\n        \"inherits_from\": \"3d_fullres\",\n        \"spacing\": [X, Y, Z],\n        \"data_identifier\": \"3d_fullres_my_spacing\"\n      }\n    }\n\nYou need to run preprocessing for this new configuration: \n`nnUNetv2_preprocess -d DATASET_ID -c 3d_fullres_my_spacing` because it changes the preprocessing. Remember to \nset a unique `data_identifier` whenever you make modifications to the preprocessed data!\n\n## Adding a cascade to a dataset where it does not exist\nHippocampus is small. It doesn't have a cascade. It also doesn't really make sense to add a cascade here but hey for \nthe sake of demonstration we can do that.\nWe change the following things here:\n\n- `spacing`: The lowres stage should operate at a lower resolution\n- we modify the `median_image_size_in_voxels` entry as a guide for what original image sizes we deal with\n- we set some patch size that is inspired by `median_image_size_in_voxels`\n- we need to remember that the patch size must be divisible by 2**num_pool in each axis!\n- network parameters such as kernel sizes, pooling operations are changed accordingly\n- we need to specify the name of the next stage\n- we need to add the highres stage\n\nThis is what this would look like (comparisons with 3d_fullres given as reference):\n\n    \"configurations\": {\n      \"3d_lowres\": {\n        \"inherits_from\": \"3d_fullres\",\n        \"data_identifier\": \"3d_lowres\",\n        \"spacing\": [2.0, 2.0, 2.0], # from [1.0, 1.0, 1.0] in 3d_fullres\n        \"median_image_size_in_voxels\": [18, 25, 18], # from [36, 50, 35]\n        \"patch_size\": [20, 28, 20], # from [40, 56, 40]\n        \"n_conv_per_stage_encoder\": [2, 2, 2], # one less entry than 3d_fullres ([2, 2, 2, 2])\n        \"n_conv_per_stage_decoder\": [2, 2], # one less entry than 
3d_fullres\n        \"num_pool_per_axis\": [2, 2, 2], # one less pooling than 3d_fullres in each dimension (3d_fullres: [3, 3, 3])\n        \"pool_op_kernel_sizes\": [[1, 1, 1], [2, 2, 2], [2, 2, 2]], # one less [2, 2, 2]\n        \"conv_kernel_sizes\": [[3, 3, 3], [3, 3, 3], [3, 3, 3]], # one less [3, 3, 3]\n        \"next_stage\": \"3d_cascade_fullres\" # name of the next stage in the cascade\n      },\n      \"3d_cascade_fullres\": { # does not need a data_identifier because we can use the data of 3d_fullres\n        \"inherits_from\": \"3d_fullres\",\n        \"previous_stage\": \"3d_lowres\" # name of the previous stage\n      }\n    }\n\nTo better understand the components describing the network topology in our plans files, please read section 6.2 \nin the [supplementary information](https://static-content.springer.com/esm/art%3A10.1038%2Fs41592-020-01008-z/MediaObjects/41592_2020_1008_MOESM1_ESM.pdf) \n(page 13) of our paper!"
  },
  {
    "path": "documentation/extending_nnunet.md",
    "content": "# Extending nnU-Net\nWe hope that the new structure of nnU-Net v2 makes it much more intuitive on how to modify it! We cannot give an \nextensive tutorial on how each and every bit of it can be modified. It is better for you to search for the position \nin the repository where the thing you intend to change is implemented and start working your way through the code from \nthere. Setting breakpoints and debugging into nnU-Net really helps in understanding it and thus will help you make the \nnecessary modifications!\n\nHere are some things you might want to read before you start:\n- Editing nnU-Net configurations through plans files is really powerful now and allows you to change a lot of things regarding \npreprocessing, resampling, network topology etc. Read [this](explanation_plans_files.md)!\n- [Image normalization](explanation_normalization.md) and [i/o formats](dataset_format.md#supported-file-formats) are easy to extend!\n- Manual data splits can be defined as described [here](manual_data_splits.md)\n- You can chain arbitrary configurations together into cascades, see [this again](explanation_plans_files.md)\n- Read about our support for [region-based training](region_based_training.md)\n- If you intend to modify the training procedure (loss, sampling, data augmentation, lr scheduler, etc) then you need \nto implement your own trainer class. Best practice is to create a class that inherits from nnUNetTrainer and \nimplements the necessary changes. Head over to our [trainer classes folder](../nnunetv2/training/nnUNetTrainer) for \ninspiration! There will be similar trainers for what you intend to change and you can take them as a guide. nnUNetTrainer \nare structured similarly to PyTorch lightning trainers, this should also make things easier!\n- Integrating new network architectures can be done in two ways:\n  - Quick and dirty: implement a new nnUNetTrainer class and overwrite its `build_network_architecture` function. 
\n  Make sure your architecture is compatible with deep supervision (if not, use `nnUNetTrainerNoDeepSupervision`\n  as basis!) and that it can handle the patch sizes that are thrown at it! Your architecture should NOT apply any \n  nonlinearities at the end (softmax, sigmoid etc). nnU-Net does that!   \n  - The 'proper' (but difficult) way: Build a dynamically configurable architecture such as the `PlainConvUNet` class \n  used by default. It needs to have some sort of GPU memory estimation method that can be used to evaluate whether \n  certain patch sizes and \n  topologies fit into a specified GPU memory target. Build a new `ExperimentPlanner` that can configure your new \n  class and communicate with its memory budget estimation. Run `nnUNetv2_plan_and_preprocess` while specifying your \n  custom `ExperimentPlanner` and a custom `plans_name`. Implement a nnUNetTrainer that can use the plans generated by \n  your `ExperimentPlanner` to instantiate the network architecture. Specify your plans and trainer when running `nnUNetv2_train`. \n  It always pays off to first read and understand the corresponding nnU-Net code and use it as a template for your implementation!\n- Remember that multi-GPU training, region-based training, ignore label and cascaded training are now simply integrated \ninto one unified nnUNetTrainer class. No separate classes needed (remember that when implementing your own trainer \nclasses and ensure support for all of these features! Or raise `NotImplementedError`)\n\n[//]: # (- Read about our support for [ignore label]&#40;ignore_label.md&#41; and [region-based training]&#40;region_based_training.md&#41;)\n"
  },
  {
    "path": "documentation/how_to_use_nnunet.md",
    "content": "## **2024-04-18 UPDATE: New residual encoder UNet presets available!**\nThe recommended nnU-Net presets have changed! See [here](resenc_presets.md) how to unlock them!\n\n\n## How to run nnU-Net on a new dataset\n\n\nGiven some dataset, nnU-Net fully automatically configures an entire segmentation pipeline that matches its properties.\nnnU-Net covers the entire pipeline, from preprocessing to model configuration, model training, postprocessing\nall the way to ensembling. After running nnU-Net, the trained model(s) can be applied to the test cases for inference.\n\n### Dataset Format\nnnU-Net expects datasets in a structured format. This format is inspired by the data structure of\nthe [Medical Segmentation Decathlon](http://medicaldecathlon.com/). Please read\n[this](dataset_format.md) for information on how to set up datasets to be compatible with nnU-Net.\n\n**Since version 2 we support multiple image file formats (.nii.gz, .png, .tif, ...)! Read the dataset_format \ndocumentation to learn more!**\n\n**Datasets from nnU-Net v1 can be converted to V2 by running `nnUNetv2_convert_old_nnUNet_dataset INPUT_FOLDER \nOUTPUT_DATASET_NAME`.** Remember that v2 calls datasets DatasetXXX_Name (not Task) where XXX is a 3-digit number.\nPlease provide the **path** to the old task, not just the Task name. nnU-Net V2 doesn't know where v1 tasks were!\n\n### Experiment planning and preprocessing\nGiven a new dataset, nnU-Net will extract a dataset fingerprint (a set of dataset-specific properties such as\nimage sizes, voxel spacings, intensity information etc). This information is used to design three U-Net configurations. \nEach of these pipelines operates on its own preprocessed version of the dataset.\n\nThe easiest way to run fingerprint extraction, experiment planning and preprocessing is to use:\n\n```bash\nnnUNetv2_plan_and_preprocess -d DATASET_ID --verify_dataset_integrity\n```\n\nWhere `DATASET_ID` is the dataset id (duh). 
We recommend `--verify_dataset_integrity` whenever it's the first time \nyou run this command. This will check for some of the most common error sources! Fingerprint extraction and\npreprocessing show a progress bar by default. Use `--no_pbar` if you want to suppress it, for example in\nnon-interactive cluster environments.\n\nYou can also process several datasets at once by giving `-d 1 2 3 [...]`. If you already know what U-Net configuration \nyou need you can also specify that with `-c 3d_fullres` (make sure to adapt -np in this case!). For more information \nabout all the options available to you please run `nnUNetv2_plan_and_preprocess -h`.\n\nnnUNetv2_plan_and_preprocess will create a new subfolder in your nnUNet_preprocessed folder named after the dataset. \nOnce the command is completed there will be a dataset_fingerprint.json file as well as a nnUNetPlans.json file for you to look at \n(in case you are interested!). There will also be subfolders containing the preprocessed data for your UNet configurations.\n\n[Optional]\nIf you prefer to keep things separate, you can also use `nnUNetv2_extract_fingerprint`, `nnUNetv2_plan_experiment` \nand `nnUNetv2_preprocess` (in that order). `nnUNetv2_extract_fingerprint` and `nnUNetv2_preprocess` also support\n`--no_pbar`.\n\n### Model training\n#### Overview\nYou pick which configurations (2d, 3d_fullres, 3d_lowres, 3d_cascade_fullres) should be trained! If you have no idea \nwhat performs best on your data, just run all of them and let nnU-Net identify the best one. It's up to you!\n\nnnU-Net trains all configurations in a 5-fold cross-validation over the training cases. 
This is 1) needed so that \nnnU-Net can estimate the performance of each configuration and tell you which one should be used for your \nsegmentation problem and 2) a natural way of obtaining a good model ensemble (average the output of these 5 models \nfor prediction) to boost performance.\n\nYou can influence the splits nnU-Net uses for 5-fold cross-validation (see [here](manual_data_splits.md)). If you \nprefer to train a single model on all training cases, this is also possible (see below).\n\n**Note that not all U-Net configurations are created for all datasets. In datasets with small image sizes, the U-Net\ncascade (and with it the 3d_lowres configuration) is omitted because the patch size of the full resolution U-Net \nalready covers a large part of the input images.**\n\nTraining models is done with the `nnUNetv2_train` command. The general structure of the command is:\n```bash\nnnUNetv2_train DATASET_NAME_OR_ID UNET_CONFIGURATION FOLD [additional options, see -h]\n```\n\nUNET_CONFIGURATION is a string that identifies the requested U-Net configuration (defaults: 2d, 3d_fullres, 3d_lowres, \n3d_cascade_fullres). DATASET_NAME_OR_ID specifies what dataset should be trained on and FOLD specifies which fold of \nthe 5-fold-cross-validation is trained.\n\nnnU-Net stores a checkpoint every 50 epochs. If you need to continue a previous training, just add a `--c` to the\ntraining command.\n\nIMPORTANT: If you plan to use `nnUNetv2_find_best_configuration` (see below) add the `--npz` flag. This makes \nnnU-Net save the softmax outputs during the final validation. They are needed for that. 
Exported softmax\npredictions are very large and therefore can take up a lot of disk space, which is why this is not enabled by default.\nIf you ran initially without the `--npz` flag but now require the softmax predictions, simply rerun the validation with:\n```bash\nnnUNetv2_train DATASET_NAME_OR_ID UNET_CONFIGURATION FOLD --val --npz\n```\n\nYou can specify the device nnU-net should use by using `-device DEVICE`. DEVICE can only be cpu, cuda or mps. If \nyou have multiple GPUs, please select the gpu id using `CUDA_VISIBLE_DEVICES=X nnUNetv2_train [...]` (requires device to be cuda).\n\nSee `nnUNetv2_train -h` for additional options.\n\n### 2D U-Net\nFor FOLD in [0, 1, 2, 3, 4], run:\n```bash\nnnUNetv2_train DATASET_NAME_OR_ID 2d FOLD [--npz]\n```\n\n### 3D full resolution U-Net\nFor FOLD in [0, 1, 2, 3, 4], run:\n```bash\nnnUNetv2_train DATASET_NAME_OR_ID 3d_fullres FOLD [--npz]\n```\n\n### 3D U-Net cascade\n#### 3D low resolution U-Net\nFor FOLD in [0, 1, 2, 3, 4], run:\n```bash\nnnUNetv2_train DATASET_NAME_OR_ID 3d_lowres FOLD [--npz]\n```\n\n#### 3D full resolution U-Net\nFor FOLD in [0, 1, 2, 3, 4], run:\n```bash\nnnUNetv2_train DATASET_NAME_OR_ID 3d_cascade_fullres FOLD [--npz]\n```\n**Note that the 3D full resolution U-Net of the cascade requires the five folds of the low resolution U-Net to be\ncompleted!**\n\nThe trained models will be written to the nnUNet_results folder. 
Each training obtains an automatically generated\noutput folder name:\n\nnnUNet_results/DatasetXXX_MYNAME/TRAINER_CLASS_NAME__PLANS_NAME__CONFIGURATION/FOLD\n\nFor Dataset002_Heart (from the MSD), for example, this looks like this:\n\n    nnUNet_results/\n    ├── Dataset002_Heart\n        │── nnUNetTrainer__nnUNetPlans__2d\n        │    ├── fold_0\n        │    ├── fold_1\n        │    ├── fold_2\n        │    ├── fold_3\n        │    ├── fold_4\n        │    ├── dataset.json\n        │    ├── dataset_fingerprint.json\n        │    └── plans.json\n        └── nnUNetTrainer__nnUNetPlans__3d_fullres\n             ├── fold_0\n             ├── fold_1\n             ├── fold_2\n             ├── fold_3\n             ├── fold_4\n             ├── dataset.json\n             ├── dataset_fingerprint.json\n             └── plans.json\n\nNote that 3d_lowres and 3d_cascade_fullres do not exist here because this dataset did not trigger the cascade. In each\nmodel training output folder (each of the fold_x folders), the following files will be created:\n- debug.json: Contains a summary of blueprint and inferred parameters used for training this model as well as a \nbunch of additional stuff. Not easy to read, but very useful for debugging ;-)\n- checkpoint_best.pth: checkpoint file of the best model identified during training. Not used right now unless you \nexplicitly tell nnU-Net to use it.\n- checkpoint_final.pth: checkpoint file of the final model (after training has ended). This is what is used for both \nvalidation and inference.\n- network_architecture.pdf (only if hiddenlayer is installed!): a pdf document with a figure of the network architecture in it.\n- progress.png: Shows losses, pseudo dice, learning rate and epoch times over the course of the training. At the top is \na plot of the training (blue) and validation (red) loss during training. Also shows an approximation of\n  the dice (green) as well as a moving average of it (dotted green line). 
This approximation is the average Dice score \n  of the foreground classes. **It needs to be taken with a big (!) \n  grain of salt** because it is computed on randomly drawn patches from the validation\n  data at the end of each epoch, and the aggregation of TP, FP and FN for the Dice computation treats the patches as if\n  they all originate from the same volume ('global Dice'; we do not compute a Dice for each validation case and then\n  average over all cases but pretend that there is only one validation case from which we sample patches). The reason for\n  this is that the 'global Dice' is easy to compute during training and is still quite useful to evaluate whether a model\n  is training at all or not. A proper validation takes way too long to be done each epoch. It is run at the end of the training.\n- validation: in this folder are the predicted validation cases after the training has finished. The summary.json file in here\n  contains the validation metrics (a mean over all cases is provided at the start of the file). If `--npz` was set then \nthe compressed softmax outputs (saved as .npz files) are in here as well. \n\nDuring training it is often useful to watch the progress. We therefore recommend that you have a look at the generated\nprogress.png when running the first training. It will be updated after each epoch.\n\nTraining times largely depend on the GPU. The smallest GPU we recommend for training is the Nvidia RTX 2080ti. With \nthat all network trainings take less than 2 days. Refer to our [benchmarks](benchmarking.md) to see if your system is \nperforming as expected.\n\n### Using multiple GPUs for training\n\nIf multiple GPUs are at your disposal, the best way of using them is to train multiple nnU-Net trainings at once, one \non each GPU. 
This is because data parallelism never scales perfectly linearly, especially not with small networks such \nas the ones used by nnU-Net.\n\nExample:\n\n```bash\nCUDA_VISIBLE_DEVICES=0 nnUNetv2_train DATASET_NAME_OR_ID 2d 0 [--npz] & # train on GPU 0\nCUDA_VISIBLE_DEVICES=1 nnUNetv2_train DATASET_NAME_OR_ID 2d 1 [--npz] & # train on GPU 1\nCUDA_VISIBLE_DEVICES=2 nnUNetv2_train DATASET_NAME_OR_ID 2d 2 [--npz] & # train on GPU 2\nCUDA_VISIBLE_DEVICES=3 nnUNetv2_train DATASET_NAME_OR_ID 2d 3 [--npz] & # train on GPU 3\nCUDA_VISIBLE_DEVICES=4 nnUNetv2_train DATASET_NAME_OR_ID 2d 4 [--npz] & # train on GPU 4\n...\nwait\n```\n\n**Important: The first time a training is run nnU-Net will extract the preprocessed data into uncompressed numpy \narrays for speed reasons! This operation must be completed before starting more than one training of the same \nconfiguration! Wait with starting subsequent folds until the first training is using the GPU! Depending on the \ndataset size and your System this should only take a couple of minutes at most.**\n\nIf you insist on running DDP multi-GPU training, we got you covered:\n\n`nnUNetv2_train DATASET_NAME_OR_ID 2d 0 [--npz] -num_gpus X`\n\nAgain, note that this will be slower than running separate training on separate GPUs. DDP only makes sense if you have \nmanually interfered with the nnU-Net configuration and are training larger models with larger patch and/or batch sizes!\n\nImportant when using `-num_gpus`:\n1) If you train using, say, 2 GPUs but have more GPUs in the system you need to specify which GPUs should be used via \nCUDA_VISIBLE_DEVICES=0,1 (or whatever your ids are).\n2) You cannot specify more GPUs than you have samples in your minibatches. If the batch size is 2, 2 GPUs is the maximum!\n3) Make sure your batch size is divisible by the numbers of GPUs you use or you will not make good use of your hardware.\n\nIn contrast to the old nnU-Net, DDP is now completely hassle free. 
Enjoy!\n\n### Automatically determine the best configuration\nOnce the desired configurations were trained (full cross-validation) you can tell nnU-Net to automatically identify \nthe best combination for you:\n\n```commandline\nnnUNetv2_find_best_configuration DATASET_NAME_OR_ID -c CONFIGURATIONS \n```\n\n`CONFIGURATIONS` hereby is the list of configurations you would like to explore. Per default, ensembling is enabled \nmeaning that nnU-Net will generate all possible combinations of ensembles (2 configurations per ensemble). This requires \nthe .npz files containing the predicted probabilities of the validation set to be present (use `nnUNetv2_train` with \n`--npz` flag, see above). You can disable ensembling by setting the `--disable_ensembling` flag.\n\nSee `nnUNetv2_find_best_configuration -h` for more options.\n\nnnUNetv2_find_best_configuration will also automatically determine the postprocessing that should be used. \nPostprocessing in nnU-Net only considers the removal of all but the largest component in the prediction (once for \nforeground vs background and once for each label/region).\n\nOnce completed, the command will print to your console exactly what commands you need to run to make predictions. It \nwill also create two files in the `nnUNet_results/DATASET_NAME` folder for you to inspect: \n- `inference_instructions.txt` again contains the exact commands you need to use for predictions\n- `inference_information.json` can be inspected to see the performance of all configurations and ensembles, as well \nas the effect of the postprocessing plus some debug information. 
\n\n### Run inference\nRemember that the data located in the input folder must have the same file endings as the dataset you trained the model on \nand must adhere to the nnU-Net naming scheme for image files (see [dataset format](dataset_format.md) and \n[inference data format](dataset_format_inference.md)!)\n\n`nnUNetv2_find_best_configuration` (see above) will print a string to the terminal with the inference commands you need to use.\nThe easiest way to run inference is to simply use these commands.\n\nIf you wish to manually specify the configuration(s) used for inference, use the following commands:\n\n#### Run prediction\nFor each of the desired configurations, run:\n```\nnnUNetv2_predict -i INPUT_FOLDER -o OUTPUT_FOLDER -d DATASET_NAME_OR_ID -c CONFIGURATION --save_probabilities\n```\n\nOnly specify `--save_probabilities` if you intend to use ensembling. `--save_probabilities` will make the command save the predicted\nprobabilities alongside the predicted segmentation masks, which requires a lot of disk space.\n\nPlease select a separate `OUTPUT_FOLDER` for each configuration!\n\nNote that per default, inference will be done with all 5 folds from the cross-validation as an ensemble. We very \nstrongly recommend you use all 5 folds. Thus, all 5 folds must have been trained prior to running inference. \n\nIf you wish to make predictions with a single model, train the `all` fold and specify it in `nnUNetv2_predict`\nwith `-f all`\n\n#### Ensembling multiple configurations\nIf you wish to ensemble multiple predictions (typically from different configurations), you can do so with the following command:\n```bash\nnnUNetv2_ensemble -i FOLDER1 FOLDER2 ... -o OUTPUT_FOLDER -np NUM_PROCESSES\n```\n\nYou can specify an arbitrary number of folders, but remember that each folder needs to contain npz files that were\ngenerated by `nnUNetv2_predict`. 
Again, `nnUNetv2_ensemble -h` will tell you more about additional options.\n\n#### Apply postprocessing\nFinally, apply the previously determined postprocessing to the (ensembled) predictions: \n\n```commandline\nnnUNetv2_apply_postprocessing -i FOLDER_WITH_PREDICTIONS -o OUTPUT_FOLDER --pp_pkl_file POSTPROCESSING_FILE -plans_json PLANS_FILE -dataset_json DATASET_JSON_FILE\n```\n\n`nnUNetv2_find_best_configuration` (or its generated `inference_instructions.txt` file) will tell you where to find \nthe postprocessing file. If not you can just look for it in your results folder (it's creatively named \n`postprocessing.pkl`). If your source folder is from an ensemble, you also need to specify a `-plans_json` file and \na `-dataset_json` file that should be used (for single configuration predictions these are automatically copied \nfrom the respective training). You can pick these files from any of the ensemble members.\n\n\n## How to run inference with pretrained models\nSee [here](run_inference_with_pretrained_models.md)\n\n## How to Deploy and Run Inference with YOUR Pretrained Models\nTo facilitate the use of pretrained models on a different computer for inference purposes, follow these streamlined steps:\n1. Exporting the Model: Utilize the `nnUNetv2_export_model_to_zip` function to package your trained model into a .zip file. This file will contain all necessary model files.\n2. Transferring the Model: Transfer the .zip file to the target computer where inference will be performed.\n3. 
Importing the Model: On the new PC, use the `nnUNetv2_install_pretrained_model_from_zip` to load the pretrained model from the .zip file.\nPlease note that both computers must have nnU-Net installed along with all its dependencies to ensure compatibility and functionality of the model.\n\n[//]: # (## Examples)\n\n[//]: # ()\n[//]: # (To get you started we compiled two simple to follow examples:)\n\n[//]: # (- run a training with the 3d full resolution U-Net on the Hippocampus dataset. See [here]&#40;documentation/training_example_Hippocampus.md&#41;.)\n\n[//]: # (- run inference with nnU-Net's pretrained models on the Prostate dataset. See [here]&#40;documentation/inference_example_Prostate.md&#41;.)\n\n[//]: # ()\n[//]: # (Usability not good enough? Let us know!)\n"
  },
  {
    "path": "documentation/ignore_label.md",
    "content": "# Ignore Label\n\nThe _ignore label_ can be used to mark regions that should be ignored by nnU-Net. This can be used to \nlearn from images where only sparse annotations are available, for example in the form of scribbles or a limited \namount of annotated slices. Internally, this is accomplished by using partial losses, i.e. losses that are only \ncomputed on annotated pixels while ignoring the rest. Take a look at our \n[`DC_and_BCE_loss` loss](../nnunetv2/training/loss/compound_losses.py) to see how this is done.\nDuring inference (validation and prediction), nnU-Net will always predict dense segmentations. Metric computation in \nvalidation is of course only done on annotated pixels.\n\nSparse annotations can be used to train a model for application to new, unseen images or to autocomplete the \nprovided training cases given the sparse labels. \n\n(See our [paper](https://arxiv.org/abs/2403.12834) for more information)\n\nTypical use-cases for the ignore label are:\n- Save annotation time through sparse annotation schemes\n  - Annotation of all or a subset of slices with scribbles (Scribble Supervision)\n  - Dense annotation of a subset of slices \n  - Dense annotation of chosen patches/cubes within an image\n- Coarsely masking out faulty segmentations in the reference segmentations\n- Masking areas for other reasons\n\nIf you are using nnU-Net's ignore label, please cite the following paper in addition to the original nnU-net paper:\n\n```\nGotkowski, K., Lüth, C., Jäger, P. F., Ziegler, S., Krämer, L., Denner, S., Xiao, S., Disch, N., H., K., & Isensee, F. \n(2024). Embarrassingly Simple Scribble Supervision for 3D Medical Segmentation. ArXiv. /abs/2403.12834\n```\n\n## Usecases\n\n### Scribble Supervision\n\nScribbles are free-form drawings to coarsely annotate an image. 
As we have demonstrated in our recent [paper](https://arxiv.org/abs/2403.12834), nnU-Net's partial loss implementation enables state-of-the-art learning from partially annotated data and even surpasses many purpose-built methods for learning from scribbles. As a starting point, for each image slice and each class (including background), an interior and a border scribble should be generated:\n\n- Interior Scribble: A scribble placed randomly within the class interior of a class instance\n- Border Scribble: A scribble roughly delineating a small part of the class border of a class instance\n\nAn example of such scribble annotations is depicted in Figure 1 and an animation in Animation 1.\nDepending on the availability of data and their variability it is also possible to only annotate a subset of selected slices.\n\n<p align=\"center\">\n    <img src=\"assets/scribble_example.png\" width=\"1024px\" />\n    <figcaption>Figure 1: Examples of segmentation types with (A) depicting a dense segmentation and (B) a scribble segmentation.</figcaption>\n</figure>\n</p>\n\n<p align=\"center\">\n    <img width=\"512px\" src=\"https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExbmdndHQwMG96M3FqZWtwbHR2enUwZXhwNHVsbndzNmNpZnVlbHJ6OSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/KRJ48evmroDlIgcqcO/giphy.gif\">\n    <img width=\"512px\" src=\"https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExem10Z3ZqZHQ2MWNsMjdibG1zc3M2NzNqbG9mazdudG5raTk4d3h4MSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/ifVxQQfco5ro1gH6bQ/giphy.gif\">\n    <figcaption>Animation 1: Depiction of a dense segmentation and a scribble annotation. Background scribbles have been excluded for better visualization.</figcaption>\n</p>\n\n### Dense annotation of a subset of slices\n\nAnother form of sparse annotation is the dense annotation of a subset of slices. These slices should be selected by the user either randomly, based on visual class variation between slices or in an active learning setting. 
An example with only 10% of slices annotated is depicted in Figure 2.\n\n<p align=\"center\">\n    <img src=\"assets/amos2022_sparseseg10_2d.png\" width=\"512px\" />\n    <img src=\"assets/amos2022_sparseseg10.png\" width=\"512px\" />\n    <figcaption>Figure 2: Examples of a dense annotation of a subset of slices. The ignored areas are shown in red.</figcaption>\n</figure>\n</p>\n\n\n## Usage within nnU-Net\n\nUsage of the ignore label in nnU-Net is straightforward and only requires the definition of an _ignore_ label in the _dataset.json_.\nThis ignore label MUST be the highest integer label value in the segmentation. For example, given a background class and two foreground classes, the ignore label must have the integer value 3. The ignore label must be named _ignore_ in the _dataset.json_. Taking the BraTS dataset as an example, the labels dict of the _dataset.json_ must look like this:\n\n```python\n...\n\"labels\": {\n    \"background\": 0,\n    \"edema\": 1,\n    \"non_enhancing_and_necrosis\": 2,\n    \"enhancing_tumor\": 3,\n    \"ignore\": 4\n},\n...\n```\n\nOf course, the ignore label is compatible with [region-based training](region_based_training.md):\n\n```python\n...\n\"labels\": {\n    \"background\": 0,\n    \"whole_tumor\": (1, 2, 3),\n    \"tumor_core\": (2, 3),\n    \"enhancing_tumor\": 3,  # or (3, )\n    \"ignore\": 4\n},\n\"regions_class_order\": (1, 2, 3),  # don't declare ignore label here! It is not predicted\n...\n```\n\nThen use the dataset as you would any other.\n\nRemember that nnU-Net runs a cross-validation. Thus, it will also evaluate on your partially annotated data. This \nwill of course work! If you wish to compare different sparse annotation strategies (through simulations for example),\nwe recommend evaluating on densely annotated images by running inference and then using `nnUNetv2_evaluate_folder` or \n`nnUNetv2_evaluate_simple`."
  },
  {
    "path": "documentation/installation_instructions.md",
    "content": "# System requirements\n\n## Operating System\nnnU-Net has been tested on Linux (Ubuntu 18.04, 20.04, 22.04; centOS, RHEL), Windows and MacOS! It should work out of the box!\n\n## Hardware requirements\nWe support GPU (recommended), CPU and Apple M1/M2 as devices (currently Apple mps does not implement 3D \nconvolutions, so you might have to use the CPU on those devices).\n\n### Hardware requirements for Training\nWe recommend you use a GPU for training as this will take a really long time on CPU or MPS (Apple M1/M2). \nFor training a GPU with at least 10 GB (popular non-datacenter options are the RTX 2080ti, RTX 3080/3090 or RTX 4080/4090) is \nrequired. We also recommend a strong CPU to go along with the GPU. 6 cores (12 threads) \nare the bare minimum! CPU requirements are mostly related to data augmentation and scale with the number of \ninput channels and target structures. Plus, the faster the GPU, the better the CPU should be!\n\n### Hardware Requirements for inference\nAgain we recommend a GPU to make predictions as this will be substantially faster than the other options. However, \ninference times are typically still manageable on CPU and MPS (Apple M1/M2). If using a GPU, it should have at least \n4 GB of available (unused) VRAM.\n\n### Example hardware configurations\nExample workstation configurations for training:\n- CPU: Ryzen 5800X - 5900X or 7900X would be even better! We have not yet tested Intel Alder/Raptor lake but they will likely work as well.\n- GPU: RTX 3090 or RTX 4090\n- RAM: 64GB\n- Storage: SSD (M.2 PCIe Gen 3 or better!)\n\nExample Server configuration for training:\n- CPU: 2x AMD EPYC7763 for a total of 128C/256T. 16C/GPU are highly recommended for fast GPUs such as the A100!\n- GPU: 8xA100 PCIe (price/performance superior to SXM variant + they use less power)\n- RAM: 1 TB\n- Storage: local SSD storage (PCIe Gen 3 or better) or ultra fast network storage\n\n(nnU-net by default uses one GPU per training. 
The server configuration can run up to 8 model trainings simultaneously)\n\n### Setting the correct number of Workers for data augmentation (training only)\nNote that you will need to manually set the number of processes nnU-Net uses for data augmentation according to your \nCPU/GPU ratio. For the server above (256 threads for 8 GPUs), a good value would be 24-30. You can do this by \nsetting the `nnUNet_n_proc_DA` environment variable (`export nnUNet_n_proc_DA=XX`). \nRecommended values (assuming a recent CPU with good IPC) are 10-12 for RTX 2080 ti, 12 for a RTX 3090, 16-18 for \nRTX 4090, 28-32 for A100. Optimal values may vary depending on the number of input channels/modalities and number of classes.\n\n# Installation instructions\nWe strongly recommend that you install nnU-Net in a virtual environment! Pip or anaconda are both fine. If you choose to \ncompile PyTorch from source (see below), you will need to use conda instead of pip. \n\nUse a recent version of Python! 3.9 or newer is guaranteed to work!\n\n**nnU-Net v2 can coexist with nnU-Net v1! Both can be installed at the same time.**\n\n1) Install [PyTorch](https://pytorch.org/get-started/locally/) as described on their website (conda/pip). Please \ninstall the latest version with support for your hardware (cuda, mps, cpu).\n**DO NOT JUST `pip install nnunetv2` WITHOUT PROPERLY INSTALLING PYTORCH FIRST**. For maximum speed, consider \n[compiling pytorch yourself](https://github.com/pytorch/pytorch#from-source) (experienced users only!). 
\n2) Install nnU-Net depending on your use case:\n    1) For use as **standardized baseline**, **out-of-the-box segmentation algorithm** or for running \n     **inference with pretrained models**:\n\n       ```pip install nnunetv2```\n\n    2) For use as integrative **framework** (this will create a copy of the nnU-Net code on your computer so that you\n   can modify it as needed):\n          ```bash\n          git clone https://github.com/MIC-DKFZ/nnUNet.git\n          cd nnUNet\n          pip install -e .\n          ```\n3) nnU-Net needs to know where you intend to save raw data, preprocessed data and trained models. For this you need to\n   set a few environment variables. Please follow the instructions [here](setting_up_paths.md).\n4) (OPTIONAL) Install [hiddenlayer](https://github.com/waleedka/hiddenlayer). hiddenlayer enables nnU-net to generate\n   plots of the network topologies it generates (see [Model training](how_to_use_nnunet.md#model-training)). \nTo install hiddenlayer,\n   run the following command:\n    ```bash\n    pip install --upgrade git+https://github.com/FabianIsensee/hiddenlayer.git\n    ```\n\nInstalling nnU-Net will add several new commands to your terminal. These commands are used to run the entire nnU-Net\npipeline. You can execute them from any location on your system. All nnU-Net commands have the prefix `nnUNetv2_` for\neasy identification.\n\nNote that these commands simply execute python scripts. If you installed nnU-Net in a virtual environment, this\nenvironment must be activated when executing the commands. You can see what scripts/functions are executed by \nchecking the project.scripts in the [pyproject.toml](../pyproject.toml) file.\n\nAll nnU-Net commands have a `-h` option which gives information on how to use them.\n"
  },
  {
    "path": "documentation/manual_data_splits.md",
    "content": "# How to generate custom splits in nnU-Net\n\nSometimes, the default 5-fold cross-validation split by nnU-Net does not fit a project. Maybe you want to run 3-fold \ncross-validation instead? Or maybe your training cases cannot be split randomly and require careful stratification. \nFear not, for nnU-Net has got you covered (it really can do anything <3).\n\nThe splits nnU-Net uses are generated in the `do_split` function of nnUNetTrainer. This function will first look for \nexisting splits, stored as a file, and if no split exists it will create one. So if you wish to influence the split, \nmanually creating a split file that will then be recognized and used is the way to go!\n\nThe split file is located in the `nnUNet_preprocessed/DATASETXXX_NAME` folder. So it is best practice to first \npopulate this folder by running `nnUNetv2_plan_and_preprocess`.\n\nSplits are stored as a .json file. They are a simple python list. The length of that list is the number of splits it \ncontains (so it's 5 in the default nnU-Net). Each list entry is a dictionary with keys 'train' and 'val'. Values are \nagain simply lists of the case identifiers in each set. 
To illustrate this, I am just messing with the Dataset002 \nfile as an example:\n\n```commandline\nIn [1]: from batchgenerators.utilities.file_and_folder_operations import load_json\n\nIn [2]: splits = load_json('splits_final.json')\n\nIn [3]: len(splits)\nOut[3]: 5\n\nIn [4]: splits[0].keys()\nOut[4]: dict_keys(['train', 'val'])\n\nIn [5]: len(splits[0]['train'])\nOut[5]: 16\n\nIn [6]: len(splits[0]['val'])\nOut[6]: 4\n\nIn [7]: print(splits[0])\n{'train': ['la_003', 'la_004', 'la_005', 'la_009', 'la_010', 'la_011', 'la_014', 'la_017', 'la_018', 'la_019', 'la_020', 'la_022', 'la_023', 'la_026', 'la_029', 'la_030'],\n'val': ['la_007', 'la_016', 'la_021', 'la_024']}\n```\n\nIf you are still not sure what splits are supposed to look like, simply download some reference dataset from the\n[Medical Decathlon](http://medicaldecathlon.com/), start some training (to generate the splits) and manually inspect \nthe .json file with your text editor of choice!\n\nIn order to generate your custom splits, all you need to do is reproduce the data structure explained above and save it as \n`splits_final.json` in the `nnUNet_preprocessed/DATASETXXX_NAME` folder. Then use `nnUNetv2_train` etc. as usual."
  },
  {
    "path": "documentation/pretraining_and_finetuning.md",
    "content": "# Pretraining with nnU-Net\n\n## Intro\n\nSo far nnU-Net only supports supervised pre-training, meaning that you train a regular nnU-Net on some pretraining dataset \nand then use the final network weights as initialization for your target dataset. \n\nAs a reminder, many training hyperparameters such as patch size and network topology differ between datasets as a \nresult of the automated dataset analysis and experiment planning nnU-Net is known for. So, out of the box, it is not \npossible to simply take the network weights from some dataset and then reuse them for another.\n\nConsequently, the plans need to be aligned between the two tasks. In this README we show how this can be achieved and \nhow the resulting weights can then be used for initialization.\n\n### Terminology\n\nThroughout this README we use the following terminology:\n\n- `pretraining dataset` is the dataset you intend to run the pretraining on\n- `finetuning dataset` is the dataset you are interested in; the one you wish to fine tune on\n\n\n## Training on the pretraining dataset\n\nIn order to obtain matching network topologies we need to transfer the plans from one dataset to another. Since we are \nonly interested in the finetuning dataset, we first need to run experiment planning (and preprocessing) for it:\n\n```bash\nnnUNetv2_plan_and_preprocess -d FINETUNING_DATASET\n```\n\nThen we need to extract the dataset fingerprint of the pretraining dataset, if not yet available:\n\n```bash\nnnUNetv2_extract_fingerprint -d PRETRAINING_DATASET\n```\n\nNow we can take the plans from the finetuning dataset and transfer it to the pretraining dataset:\n\n```bash\nnnUNetv2_move_plans_between_datasets -s FINETUNING_DATASET -t PRETRAINING_DATASET -sp FINETUNING_PLANS_IDENTIFIER -tp PRETRAINING_PLANS_IDENTIFIER\n```\n\n`FINETUNING_PLANS_IDENTIFIER` is hereby probably nnUNetPlans unless you changed the experiment planner in \nnnUNetv2_plan_and_preprocess. 
For `PRETRAINING_PLANS_IDENTIFIER` we recommend you set something custom in order to not \noverwrite default plans.\n\nNote that EVERYTHING is transferred between the datasets. Not just the network topology, batch size and patch size but \nalso the normalization scheme! Therefore, a transfer between datasets that use different normalization schemes may not \nwork well (but it could, depending on the schemes!).\n\nNote on CT normalization: Yes, also the clip values, mean and std are transferred!\n\nNow you can run the preprocessing on the pretraining dataset:\n\n```bash\nnnUNetv2_preprocess -d PRETRAINING_DATASET -plans_name PRETRAINING_PLANS_IDENTIFIER\n```\n\nAnd run the training as usual:\n\n```bash\nnnUNetv2_train PRETRAINING_DATASET CONFIG all -p PRETRAINING_PLANS_IDENTIFIER\n```\n\nNote how we use the 'all' fold to train on all available data. For pretraining it does not make sense to split the data.\n\n## Using pretrained weights\n\nOnce pretraining is completed (or you obtain compatible weights by other means) you can use them to initialize your model:\n\n```bash\nnnUNetv2_train FINETUNING_DATASET CONFIG FOLD -pretrained_weights PATH_TO_CHECKPOINT\n```\n\nSpecify the checkpoint in PATH_TO_CHECKPOINT.\n\nWhen loading pretrained weights, all layers except the segmentation layers will be used! \n\nSo far there are no specific nnUNet trainers for fine tuning, so the current recommendation is to just use \nnnUNetTrainer. You can however easily write your own trainers with learning rate ramp up, fine-tuning of segmentation \nheads or shorter training time."
  },
  {
    "path": "documentation/region_based_training.md",
    "content": "# Region-based training\n\n## What is this about?\nIn some segmentation tasks, most prominently the \n[Brain Tumor Segmentation Challenge](http://braintumorsegmentation.org/), the target areas (based on which the metric \nwill be computed) are different from the labels provided in the training data. This is the case because for some \nclinical applications, it is more relevant to detect the whole tumor, tumor core and enhancing tumor instead of the \nindividual labels (edema, necrosis and non-enhancing tumor, enhancing tumor). \n\n<img src=\"assets/regions_vs_labels.png\" width=\"768px\" />\n\nThe figure shows an example BraTS case along with label-based representation of the task (top) and region-based \nrepresentation (bottom). The challenge evaluation is done on the regions. As we have shown in our \n[BraTS 2018 contribution](https://arxiv.org/abs/1809.10483), directly optimizing those \noverlapping areas over the individual labels yields better scoring models!\n\n## What can nnU-Net do?\nnnU-Net's region-based training allows you to learn areas that are constructed by merging individual labels. For \nsome segmentation tasks this provides a benefit, as this shifts the importance allocated to different labels during training. \nMost prominently, this feature can be used to represent **hierarchical classes**, for example when organs + \nsubstructures are to be segmented. Imagine a liver segmentation problem, where vessels and tumors are also to be \nsegmented. The first target region could thus be the entire liver (including the substructures), while the remaining \ntargets are the individual substructures.\n\nImportant: nnU-Net still requires integer label maps as input and will produce integer label maps as output! 
\nRegion-based training can be used to learn overlapping labels, but there must be a way to model these overlaps \nfor nnU-Net to work (see below how this is done).\n\n## How do you use it?\n\nWhen declaring the labels in the `dataset.json` file, BraTS would typically look like this:\n\n```python\n...\n\"labels\": {\n    \"background\": 0,\n    \"edema\": 1,\n    \"non_enhancing_and_necrosis\": 2,\n    \"enhancing_tumor\": 3\n},\n...\n```\n(we use different int values than the challenge because nnU-Net needs consecutive integers!)\n\nThis representation corresponds to the upper row in the figure above.\n\nFor region-based training, the labels need to be changed to the following:\n\n```python\n...\n\"labels\": {\n    \"background\": 0,\n    \"whole_tumor\": [1, 2, 3],\n    \"tumor_core\": [2, 3],\n    \"enhancing_tumor\": 3  # or [3]\n},\n\"regions_class_order\": [1, 2, 3],\n...\n```\nThis corresponds to the bottom row in the figure above. Note how an additional entry in the dataset.json is \nrequired: `regions_class_order`. This tells nnU-Net how to convert the region representations back to an integer map. \nIt essentially just tells nnU-Net what labels to place for which region in what order. The length of the \nlist here needs to be the same as the number of regions (excl background). Each element in the list corresponds \nto the label that is placed instead of the region into the final segmentation. Later entries will overwrite earlier ones! \nConcretely, for the example given here, nnU-Net \nwill firstly place the label 1 (edema) where the 'whole_tumor' region was predicted, then place the label 2 \n(non-enhancing tumor and necrosis) where the \"tumor_core\" was predicted and finally place the label 3 in the \npredicted 'enhancing_tumor' area. With each step, part of the previously set pixels \nwill be overwritten with the new label! 
So when setting your `regions_class_order`, place encompassing regions \n(like whole tumor etc) first, followed by substructures.\n\n**IMPORTANT** Because the conversion back to a segmentation map is sensitive to the order in which the regions are \ndeclared (\"place label X in the first region\") you need to make sure that this order is not perturbed! When \nautomatically generating the dataset.json, make sure the dictionary keys do not get sorted alphabetically! Set \n`sort_keys=False` in `json.dump()`!!!\n\nnnU-Net will perform the evaluation + model selection also on the regions, not the individual labels!\n\nThat's all. Easy, huh?"
  },
  {
    "path": "documentation/resenc_presets.md",
    "content": "# Residual Encoder Presets in nnU-Net\n\nWhen using these presets, please cite our recent paper on the need for rigorous validation in 3D medical image segmentation:\n\n> Isensee, F.<sup>* </sup>, Wald, T.<sup>* </sup>, Ulrich, C.<sup>* </sup>, Baumgartner, M.<sup>* </sup>, Roy, S., Maier-Hein, K.<sup>†</sup>, Jaeger, P.<sup>†</sup> (2024). nnU-Net Revisited: A Call for Rigorous Validation in 3D Medical Image Segmentation. arXiv preprint arXiv:2404.09556.\n\n*: shared first authors\\\n<sup>†</sup>: shared last authors\n\n[PAPER LINK](https://arxiv.org/pdf/2404.09556.pdf)\n\n\nResidual Encoder UNets have been supported by nnU-Net since our participation in KiTS2019, but have flown under the radar.\nThis is bound to change with our new nnUNetResEncUNet presets :raised_hands:! Especially on large datasets such as KiTS2023 and AMOS2022 \nthey offer improved segmentation performance!\n\n|                        | BTCV  | ACDC  | LiTS  | BraTS | KiTS  | AMOS  |  VRAM |  RT | Arch. | nnU |\n|------------------------|-------|-------|-------|-------|-------|-------|-------|-----|-------|-----|\n|                        | n=30  | n=200 | n=131 | n=1251| n=489 | n=360 |       |     |       |     |\n| nnU-Net (org.) 
[1]     | 83.08 | 91.54 | 80.09 | 91.24 | 86.04 | 88.64 |  7.70 |  9  |  CNN  | Yes |\n| nnU-Net ResEnc M       | 83.31 | 91.99 | 80.75 | 91.26 | 86.79 | 88.77 |  9.10 |  12 |  CNN  | Yes |\n| nnU-Net ResEnc L       | 83.35 | 91.69 | 81.60 | 91.13 | 88.17 | 89.41 | 22.70 |  35 |  CNN  | Yes |\n| nnU-Net ResEnc XL      | 83.28 | 91.48 | 81.19 | 91.18 | 88.67 | 89.68 | 36.60 |  66 |  CNN  | Yes |\n| MedNeXt L k3 [2]       | 84.70 | 92.65 | 82.14 | 91.35 | 88.25 | 89.62 | 17.30 |  68 |  CNN  | Yes |\n| MedNeXt L k5 [2]       | 85.04 | 92.62 | 82.34 | 91.50 | 87.74 | 89.73 | 18.00 | 233 |  CNN  | Yes |\n| STU-Net S [3]          | 82.92 | 91.04 | 78.50 | 90.55 | 84.93 | 88.08 |  5.20 |  10 |  CNN  | Yes |\n| STU-Net B [3]          | 83.05 | 91.30 | 79.19 | 90.85 | 86.32 | 88.46 |  8.80 |  15 |  CNN  | Yes |\n| STU-Net L [3]          | 83.36 | 91.31 | 80.31 | 91.26 | 85.84 | 89.34 | 26.50 |  51 |  CNN  | Yes |\n| SwinUNETR [4]          | 78.89 | 91.29 | 76.50 | 90.68 | 81.27 | 83.81 | 13.10 |  15 |   TF  | Yes |\n| SwinUNETRV2 [5]        | 80.85 | 92.01 | 77.85 | 90.74 | 84.14 | 86.24 | 13.40 |  15 |   TF  | Yes |\n| nnFormer [6]           | 80.86 | 92.40 | 77.40 | 90.22 | 75.85 | 81.55 |  5.70 |  8  |   TF  | Yes |\n| CoTr [7]               | 81.95 | 90.56 | 79.10 | 90.73 | 84.59 | 88.02 |  8.20 |  18 |   TF  | Yes |\n| No-Mamba Base          | 83.69 | 91.89 | 80.57 | 91.26 | 85.98 | 89.04 |  12.0 |  24 |  CNN  | Yes |\n| U-Mamba Bot [8]        | 83.51 | 91.79 | 80.40 | 91.26 | 86.22 | 89.13 | 12.40 |  24 |  Mam  | Yes |\n| U-Mamba Enc [8]        | 82.41 | 91.22 | 80.27 | 90.91 | 86.34 | 88.38 | 24.90 |  47 |  Mam  | Yes |\n| A3DS SegResNet [9,11]  | 80.69 | 90.69 | 79.28 | 90.79 | 81.11 | 87.27 | 20.00 |  22 |  CNN  |  No |\n| A3DS DiNTS [10, 11]    | 78.18 | 82.97 | 69.05 | 87.75 | 65.28 | 82.35 | 29.20 |  16 |  CNN  |  No |\n| A3DS SwinUNETR [4, 11] | 76.54 | 82.68 | 68.59 | 89.90 | 52.82 | 85.05 | 34.50 |  9  |   TF  |  No |\n\nResults taken from our paper (see 
above), reported values are Dice scores computed over 5-fold cross-validation on each \ndataset. All models trained from scratch.\n\nRT: training run time (measured on 1x Nvidia A100 PCIe 40GB)\\\nVRAM: GPU VRAM used during training, as reported by nvidia-smi\\\nArch.: CNN = convolutional neural network; TF = transformer; Mam = Mamba\\\nnnU: whether the architecture was integrated and tested with the nnU-Net framework (either by us or the original authors)\n\n## How to use the new presets\n\nWe offer three new presets, each targeted for a different GPU VRAM and compute budget:\n- **nnU-Net ResEnc M**: similar GPU budget to the standard UNet configuration. Best suited for GPUs with 9-11GB VRAM. Training time: ~12h on A100\n- **nnU-Net ResEnc L**: requires a GPU with 24GB VRAM. Training time: ~35h on A100\n- **nnU-Net ResEnc XL**: requires a GPU with 40GB VRAM. Training time: ~66h on A100\n\n### **:point_right: We recommend **nnU-Net ResEnc L** as the new default nnU-Net configuration! :point_left:**\n\nThe new presets are available as follows ((M/L/XL) = pick one!):\n1. Specify the desired configuration when running experiment planning and preprocessing: \n`nnUNetv2_plan_and_preprocess -d DATASET -pl nnUNetPlannerResEnc(M/L/XL)`. These planners use the same preprocessed\ndata folder as the standard 2d and 3d_fullres configurations since the preprocessed data is identical. Only the\n3d_lowres differs and will be saved in a different folder to allow all configurations to coexist! If you are only \nplanning to run 3d_fullres/2d and you already have this data preprocessed, you can just run \n`nnUNetv2_plan_experiment -d DATASET -pl nnUNetPlannerResEnc(M/L/XL)` to avoid preprocessing again! \n2. Now, just specify the correct plans when running `nnUNetv2_train`, `nnUNetv2_predict` etc. 
The interface is \nconsistent across all nnU-Net commands: `-p nnUNetResEncUNet(M/L/XL)Plans`  \n\nTraining results for the new presets will be stored in a dedicated folder and will not overwrite standard nnU-Net \nresults! So don't be afraid to give it a go!\n\n## Scaling ResEnc nnU-Net beyond the Presets\nThe presets differ from `ResEncUNetPlanner` in two ways:\n- They set new default values for `gpu_memory_target_in_gb` to target the respective VRAM consumptions\n- They remove the batch size cap of 0.05 (= previously one batch could not cover more pixels than 5% of the entire dataset, now it can be arbitrarily large)\n\nThe presets are merely there to make life easier, and to provide standardized configurations people can benchmark with.\nYou can easily adapt the GPU memory target to match your GPU, and to scale beyond 40GB of GPU memory. \n\nHere is an example for how to scale to 80GB VRAM on Dataset003_Liver:\n\n`nnUNetv2_plan_experiment -d 3 -pl nnUNetPlannerResEncM -gpu_memory_target 80 -overwrite_plans_name nnUNetResEncUNetPlans_80G`\n\nJust use `-p nnUNetResEncUNetPlans_80G` moving forward as outlined above! Running the example above will yield a \nwarning (\"You are running nnUNetPlannerM with a non-standard gpu_memory_target_in_gb\"). This warning can be ignored here.\n**Always change the plans identifier with `-overwrite_plans_name NEW_PLANS_NAME` when messing with the VRAM target in \norder to not overwrite preset plans!**\n\nWhy not use `ResEncUNetPlanner` -> because that one still has the 5% cap in place!\n\n### Scaling to multiple GPUs\nWhen scaling to multiple GPUs, do not just specify the combined amount of VRAM to `nnUNetv2_plan_experiment` as this \nmay result in patch sizes that are too large to be processed by individual GPUs. It is best to let this command run for \nthe VRAM budget of one GPU, and then manually edit the plans file to increase the batch size. 
You can use [configuration inheritance](explanation_plans_files.md).\nIn the configurations dictionary of the generated plans JSON file, add the following entry:\n\n```json\n        \"3d_fullres_bsXX\": {\n            \"inherits_from\": \"3d_fullres\",\n            \"batch_size\": XX\n        },\n```\nWhere XX is the new batch size. If 3d_fullres has a batch size of 2 for one GPU and you are planning to scale to 8 GPUs, make the new batch size 2x8=16!\nYou can then train the new configuration using nnU-Net's multi-GPU settings:\n\n```bash\nnnUNetv2_train DATASETID 3d_fullres_bsXX FOLD -p nnUNetResEncUNetPlans_80G -num_gpus 8\n```\n\n## Proposing a new segmentation method? Benchmark the right way!\nWhen benchmarking new segmentation methods against nnU-Net, we encourage you to benchmark against the residual encoder \nvariants. For a fair comparison, pick the variant that most closely matches the GPU memory and compute \nrequirements of your method!\n\n\n## References\n [1] Isensee, Fabian, et al. \"nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation.\" Nature methods 18.2 (2021): 203-211.\\\n [2] Roy, Saikat, et al. \"Mednext: transformer-driven scaling of convnets for medical image segmentation.\" International Conference on Medical Image Computing and Computer-Assisted Intervention. Cham: Springer Nature Switzerland, 2023.\\\n [3] Huang, Ziyan, et al. \"Stu-net: Scalable and transferable medical image segmentation models empowered by large-scale supervised pre-training.\" arXiv preprint arXiv:2304.06716 (2023).\\\n [4] Hatamizadeh, Ali, et al. \"Swin unetr: Swin transformers for semantic segmentation of brain tumors in mri images.\" International MICCAI Brainlesion Workshop. Cham: Springer International Publishing, 2021.\\\n [5] He, Yufan, et al. 
\"Swinunetr-v2: Stronger swin transformers with stagewise convolutions for 3d medical image segmentation.\" International Conference on Medical Image Computing and Computer-Assisted Intervention. Cham: Springer Nature Switzerland, 2023.\\\n [6] Zhou, Hong-Yu, et al. \"nnformer: Interleaved transformer for volumetric segmentation.\" arXiv preprint arXiv:2109.03201 (2021).\\\n [7] Xie, Yutong, et al. \"Cotr: Efficiently bridging cnn and transformer for 3d medical image segmentation.\" Medical Image Computing and Computer Assisted Intervention–MICCAI 2021: 24th International Conference, Strasbourg, France, September 27–October 1, 2021, Proceedings, Part III 24. Springer International Publishing, 2021.\\\n [8] Ma, Jun, Feifei Li, and Bo Wang. \"U-mamba: Enhancing long-range dependency for biomedical image segmentation.\" arXiv preprint arXiv:2401.04722 (2024).\\\n [9] Myronenko, Andriy. \"3D MRI brain tumor segmentation using autoencoder regularization.\" Brainlesion: Glioma, Multiple Sclerosis, Stroke and Traumatic Brain Injuries: 4th International Workshop, BrainLes 2018, Held in Conjunction with MICCAI 2018, Granada, Spain, September 16, 2018, Revised Selected Papers, Part II 4. Springer International Publishing, 2019.\\\n [10] He, Yufan, et al. \"Dints: Differentiable neural network topology search for 3d medical image segmentation.\" Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2021.\\\n [11] Auto3DSeg, MONAI 1.3.0, [LINK](https://github.com/Project-MONAI/tutorials/tree/ed8854fa19faa49083f48abf25a2c30ab9ac1c6b/auto3dseg)\n\n"
  },
  {
    "path": "documentation/run_inference_with_pretrained_models.md",
    "content": "# How to run inference with pretrained models\n**Important:** Pretrained weights from nnU-Net v1 are NOT compatible with V2. You will need to retrain with the new \nversion. But honestly, you already have a fully trained model with which you can run inference (in v1), so \njust continue using that!\n\nNot yet available for V2 :-(\nIf you wish to run inference with pretrained models, check out the old nnU-Net for now. We are working on this full steam!\n"
  },
  {
    "path": "documentation/set_environment_variables.md",
    "content": "# How to set environment variables\n\nnnU-Net requires some environment variables so that it always knows where the raw data, preprocessed data and trained \nmodels are. Depending on the operating system, these environment variables need to be set in different ways.\n\nVariables can either be set permanently (recommended!) or you can decide to set them every time you call nnU-Net. \n\n# Linux & MacOS\n\n## Permanent\nLocate the `.bashrc` file in your home folder and add the following lines to the bottom:\n\n```bash\nexport nnUNet_raw=\"/media/fabian/nnUNet_raw\"\nexport nnUNet_preprocessed=\"/media/fabian/nnUNet_preprocessed\"\nexport nnUNet_results=\"/media/fabian/nnUNet_results\"\n```\n\n(Of course you need to adapt the paths to the actual folders you intend to use).\nIf you are using a different shell, such as zsh, you will need to find the correct script for it. For zsh this is `.zshrc`.\n\n## Temporary\nJust execute the following lines whenever you run nnU-Net:\n```bash\nexport nnUNet_raw=\"/media/fabian/nnUNet_raw\"\nexport nnUNet_preprocessed=\"/media/fabian/nnUNet_preprocessed\"\nexport nnUNet_results=\"/media/fabian/nnUNet_results\"\n```\n(Of course you need to adapt the paths to the actual folders you intend to use).\n\nImportant: These variables will be deleted if you close your terminal! They will also only apply to the current \nterminal window and DO NOT transfer to other terminals!\n\nAlternatively you can also just prefix them to your nnU-Net commands:\n\n`nnUNet_results=\"/media/fabian/nnUNet_results\" nnUNet_preprocessed=\"/media/fabian/nnUNet_preprocessed\" nnUNetv2_train[...]`\n\n## Verify that environment parameters are set\nYou can always execute `echo ${nnUNet_raw}` etc to print the environment variables. 
This will return an empty string if \nthey were not set.\n\n# Windows\nUseful links:\n- [https://www3.ntu.edu.sg](https://www3.ntu.edu.sg/home/ehchua/programming/howto/Environment_Variables.html#:~:text=To%20set%20(or%20change)%20a,it%20to%20an%20empty%20string.)\n- [https://phoenixnap.com](https://phoenixnap.com/kb/windows-set-environment-variable)\n\n## Permanent\nSee `Set Environment Variable in Windows via GUI` [here](https://phoenixnap.com/kb/windows-set-environment-variable). \nOr read about setx (command prompt).\n\n## Temporary\nJust execute the following before you run nnU-Net:\n\n(PowerShell)\n```PowerShell\n$Env:nnUNet_raw = \"C:/Users/fabian/nnUNet_raw\"\n$Env:nnUNet_preprocessed = \"C:/Users/fabian/nnUNet_preprocessed\"\n$Env:nnUNet_results = \"C:/Users/fabian/nnUNet_results\"\n```\n\n(Command Prompt)\n```Command Prompt\nset nnUNet_raw=C:/Users/fabian/nnUNet_raw\nset nnUNet_preprocessed=C:/Users/fabian/nnUNet_preprocessed\nset nnUNet_results=C:/Users/fabian/nnUNet_results\n```\n\n(Of course you need to adapt the paths to the actual folders you intend to use).\n\nImportant: These variables will be deleted if you close your session! They will also only apply to the current \nwindow and DO NOT transfer to other sessions!\n\n## Verify that environment parameters are set\nPrinting in Windows works differently depending on the environment you are in:\n\nPowerShell: `echo $Env:[variable_name]`\n\nCommand Prompt: `echo %variable_name%`\n"
  },
  {
    "path": "documentation/setting_up_paths.md",
    "content": "# Setting up Paths\n\nnnU-Net relies on environment variables to know where raw data, preprocessed data and trained model weights are stored. \nTo use the full functionality of nnU-Net, the following three environment variables must be set:\n\n1) `nnUNet_raw`: This is where you place the raw datasets. This folder will have one subfolder for each dataset names \nDatasetXXX_YYY where XXX is a 3-digit identifier (such as 001, 002, 043, 999, ...) and YYY is the (unique) \ndataset name. The datasets must be in nnU-Net format, see [here](dataset_format.md).\n\n    Example tree structure:\n    ```\n    nnUNet_raw/Dataset001_NAME1\n    ├── dataset.json\n    ├── imagesTr\n    │   ├── ...\n    ├── imagesTs\n    │   ├── ...\n    └── labelsTr\n        ├── ...\n    nnUNet_raw/Dataset002_NAME2\n    ├── dataset.json\n    ├── imagesTr\n    │   ├── ...\n    ├── imagesTs\n    │   ├── ...\n    └── labelsTr\n        ├── ...\n    ```\n\n2) `nnUNet_preprocessed`: This is the folder where the preprocessed data will be saved. The data will also be read from \nthis folder during training. It is important that this folder is located on a drive with low access latency and high \nthroughput (such as a nvme SSD (PCIe gen 3 is sufficient)).\n\n3) `nnUNet_results`: This specifies where nnU-Net will save the model weights. If pretrained models are downloaded, this \nis where it will save them.\n\n### How to set environment variables\n\nSee [here](set_environment_variables.md)."
  },
  {
    "path": "documentation/tldr_migration_guide_from_v1.md",
    "content": "# TLDR Migration Guide from nnU-Net V1\n\n- nnU-Net V2 can be installed simultaneously with V1. They won't get in each other's way\n- The environment variables needed for V2 have slightly different names. Read [this](setting_up_paths.md). \n- nnU-Net V2 datasets are called DatasetXXX_NAME. Not Task.\n- Datasets have the same structure (imagesTr, labelsTr, dataset.json) but we now support more \n[file types](dataset_format.md#supported-file-formats). The dataset.json is simplified. Use `generate_dataset_json` \nfrom nnunetv2.dataset_conversion.generate_dataset_json.py. \n- Careful: labels are now no longer declared as value:name but name:value. This has to do with [hierarchical labels](region_based_training.md). \n- nnU-Net v2 commands start with `nnUNetv2...`. They work mostly (but not entirely) the same. Just use the `-h` option.\n- You can transfer your V1 raw datasets to V2 with `nnUNetv2_convert_old_nnUNet_dataset`. You cannot transfer trained \nmodels. Continue to use the old nnU-Net Version for making inference with those.\n- These are the commands you are most likely to be using (in that order)\n  - `nnUNetv2_plan_and_preprocess`. Example: `nnUNetv2_plan_and_preprocess -d 2`\n  - `nnUNetv2_train`. Example: `nnUNetv2_train 2 3d_fullres 0`\n  - `nnUNetv2_find_best_configuration`. Example: `nnUNetv2_find_best_configuration 2 -c 2d 3d_fullres`. This command\n    will now create a `inference_instructions.txt` file in your `nnUNet_preprocessed/DatasetXXX_NAME/` folder which\n    tells you exactly how to do inference.\n  - `nnUNetv2_predict`. Example: `nnUNetv2_predict -i INPUT_FOLDER -o OUTPUT_FOLDER -c 3d_fullres -d 2`\n  - `nnUNetv2_apply_postprocessing` (see inference_instructions.txt)\n"
  },
  {
    "path": "nnunetv2/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/batch_running/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/batch_running/benchmarking/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/batch_running/benchmarking/generate_benchmarking_commands.py",
    "content": "if __name__ == '__main__':\n    \"\"\"\n    This code probably only works within the DKFZ infrastructure (using LSF). You will need to adapt it to your scheduler! \n    \"\"\"\n    gpu_models = [#'NVIDIAA100_PCIE_40GB', 'NVIDIAGeForceRTX2080Ti', 'NVIDIATITANRTX', 'TeslaV100_SXM2_32GB',\n                  'NVIDIAA100_SXM4_40GB']#, 'TeslaV100_PCIE_32GB']\n    datasets = [2, 3, 4, 5]\n    trainers = ['nnUNetTrainerBenchmark_5epochs', 'nnUNetTrainerBenchmark_5epochs_noDataLoading']\n    plans = ['nnUNetPlans']\n    configs = ['2d', '2d_bs3x', '2d_bs6x', '3d_fullres', '3d_fullres_bs3x', '3d_fullres_bs6x']\n    num_gpus = 1\n\n    benchmark_configurations = {d: configs for d in datasets}\n\n    exclude_hosts = \"-R \\\"select[hname!='e230-dgxa100-1']'\\\"\"\n    resources = \"-R \\\"tensorcore\\\"\"\n    queue = \"-q gpu\"\n    preamble = \"-L /bin/bash \\\"source ~/load_env_torch210.sh && \"\n    train_command = 'nnUNet_compile=False nnUNet_results=/dkfz/cluster/gpu/checkpoints/OE0441/isensee/nnUNet_results_remake_benchmark nnUNetv2_train'\n\n    folds = (0, )\n\n    use_these_modules = {\n        tr: plans for tr in trainers\n    }\n\n    additional_arguments = f' -num_gpus {num_gpus}'  # ''\n\n    output_file = \"/home/isensee/deleteme.txt\"\n    with open(output_file, 'w') as f:\n        for g in gpu_models:\n            gpu_requirements = f\"-gpu num={num_gpus}:j_exclusive=yes:gmodel={g}\"\n            for tr in use_these_modules.keys():\n                for p in use_these_modules[tr]:\n                    for dataset in benchmark_configurations.keys():\n                        for config in benchmark_configurations[dataset]:\n                            for fl in folds:\n                                command = f'bsub {exclude_hosts} {resources} {queue} {gpu_requirements} {preamble} {train_command} {dataset} {config} {fl} -tr {tr} -p {p}'\n                                if additional_arguments is not None and len(additional_arguments) > 0:\n       
                             command += f' {additional_arguments}'\n                                f.write(f'{command}\\\"\\n')"
  },
  {
    "path": "nnunetv2/batch_running/benchmarking/summarize_benchmark_results.py",
    "content": "from batchgenerators.utilities.file_and_folder_operations import join, load_json, isfile\nfrom nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name\nfrom nnunetv2.paths import nnUNet_results\nfrom nnunetv2.utilities.file_path_utilities import get_output_folder\n\nif __name__ == '__main__':\n    trainers = ['nnUNetTrainerBenchmark_5epochs', 'nnUNetTrainerBenchmark_5epochs_noDataLoading']\n    datasets = [2, 3, 4, 5]\n    plans = ['nnUNetPlans']\n    configs = ['2d', '2d_bs3x', '2d_bs6x', '3d_fullres', '3d_fullres_bs3x', '3d_fullres_bs6x']\n    output_file = join(nnUNet_results, 'benchmark_results.csv')\n\n    torch_version = '2.1.0.dev20230330'#\"2.0.0\"#\"2.1.0.dev20230328\"  #\"1.11.0a0+gitbc2c6ed\"  #\n    cudnn_version = 8700  # 8302  #\n    num_gpus = 1\n\n    unique_gpus = set()\n\n    # collect results in the most janky way possible. Amazing coding skills!\n    all_results = {}\n    for tr in trainers:\n        all_results[tr] = {}\n        for p in plans:\n            all_results[tr][p] = {}\n            for c in configs:\n                all_results[tr][p][c] = {}\n                for d in datasets:\n                    dataset_name = maybe_convert_to_dataset_name(d)\n                    output_folder = get_output_folder(dataset_name, tr, p, c, fold=0)\n                    expected_benchmark_file = join(output_folder, 'benchmark_result.json')\n                    all_results[tr][p][c][d] = {}\n                    if isfile(expected_benchmark_file):\n                        # filter results for what we want\n                        results = [i for i in load_json(expected_benchmark_file).values()\n                                   if i['num_gpus'] == num_gpus and i['cudnn_version'] == cudnn_version and\n                                   i['torch_version'] == torch_version]\n                        for r in results:\n                            all_results[tr][p][c][d][r['gpu_name']] = r\n                        
    unique_gpus.add(r['gpu_name'])\n\n    # haha. Fuck this. Collect GPUs in the code above.\n    # unique_gpus = np.unique([i[\"gpu_name\"] for tr in trainers for p in plans for c in configs for d in datasets for i in all_results[tr][p][c][d]])\n\n    unique_gpus = list(unique_gpus)\n    unique_gpus.sort()\n\n    with open(output_file, 'w') as f:\n        f.write('Dataset,Trainer,Plans,Config')\n        for g in unique_gpus:\n            f.write(f\",{g}\")\n        f.write(\"\\n\")\n        for d in datasets:\n            for tr in trainers:\n                for p in plans:\n                    for c in configs:\n                        gpu_results = []\n                        for g in unique_gpus:\n                            if g in all_results[tr][p][c][d].keys():\n                                gpu_results.append(round(all_results[tr][p][c][d][g][\"fastest_epoch\"], ndigits=2))\n                            else:\n                                gpu_results.append(\"MISSING\")\n                        # skip if all are missing\n                        if all([i == 'MISSING' for i in gpu_results]):\n                            continue\n                        f.write(f\"{d},{tr},{p},{c}\")\n                        for g in gpu_results:\n                            f.write(f\",{g}\")\n                        f.write(\"\\n\")\n            f.write(\"\\n\")\n\n"
  },
  {
    "path": "nnunetv2/batch_running/collect_results_custom_Decathlon.py",
    "content": "from typing import Tuple\n\nimport numpy as np\nfrom batchgenerators.utilities.file_and_folder_operations import *\n\nfrom nnunetv2.evaluation.evaluate_predictions import load_summary_json\nfrom nnunetv2.paths import nnUNet_results\nfrom nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name, convert_dataset_name_to_id\nfrom nnunetv2.utilities.file_path_utilities import get_output_folder\n\n\ndef collect_results(trainers: dict, datasets: List, output_file: str,\n                    configurations=(\"2d\", \"3d_fullres\", \"3d_lowres\", \"3d_cascade_fullres\"),\n                    folds=tuple(np.arange(5))):\n    results_dirs = (nnUNet_results,)\n    datasets_names = [maybe_convert_to_dataset_name(i) for i in datasets]\n    with open(output_file, 'w') as f:\n        for i, d in zip(datasets, datasets_names):\n            for c in configurations:\n                for module in trainers.keys():\n                    for plans in trainers[module]:\n                        for r in results_dirs:\n                            expected_output_folder = get_output_folder(d, module, plans, c)\n                            if isdir(expected_output_folder):\n                                results_folds = []\n                                f.write(f\"{d},{c},{module},{plans},{r}\")\n                                for fl in folds:\n                                    expected_output_folder_fold = get_output_folder(d, module, plans, c, fl)\n                                    expected_summary_file = join(expected_output_folder_fold, \"validation\",\n                                                                 \"summary.json\")\n                                    if not isfile(expected_summary_file):\n                                        print('expected output file not found:', expected_summary_file)\n                                        f.write(\",\")\n                                        results_folds.append(np.nan)\n    
                                else:\n                                        foreground_mean = load_summary_json(expected_summary_file)['foreground_mean'][\n                                            'Dice']\n                                        results_folds.append(foreground_mean)\n                                        f.write(f\",{foreground_mean:02.4f}\")\n                                f.write(f\",{np.nanmean(results_folds):02.4f}\\n\")\n\n\ndef summarize(input_file, output_file, folds: Tuple[int, ...], configs: Tuple[str, ...], datasets, trainers):\n    txt = np.loadtxt(input_file, dtype=str, delimiter=',')\n    num_folds = txt.shape[1] - 6\n    valid_configs = {}\n    for d in datasets:\n        if isinstance(d, int):\n            d = maybe_convert_to_dataset_name(d)\n        configs_in_txt = np.unique(txt[:, 1][txt[:, 0] == d])\n        valid_configs[d] = [i for i in configs_in_txt if i in configs]\n    assert max(folds) < num_folds\n\n    with open(output_file, 'w') as f:\n        f.write(\"name\")\n        for d in valid_configs.keys():\n            for c in valid_configs[d]:\n                f.write(\",%d_%s\" % (convert_dataset_name_to_id(d), c[:4]))\n        f.write(',mean\\n')\n        valid_entries = txt[:, 4] == nnUNet_results\n        for t in trainers.keys():\n            trainer_locs = valid_entries & (txt[:, 2] == t)\n            for pl in trainers[t]:\n                f.write(f\"{t}__{pl}\")\n                trainer_plan_locs = trainer_locs & (txt[:, 3] == pl)\n                r = []\n                for d in valid_configs.keys():\n                    trainer_plan_d_locs = trainer_plan_locs & (txt[:, 0] == d)\n                    for v in valid_configs[d]:\n                        trainer_plan_d_config_locs = trainer_plan_d_locs & (txt[:, 1] == v)\n                        if np.any(trainer_plan_d_config_locs):\n                            # we cannot have more than one row\n                            assert 
np.sum(trainer_plan_d_config_locs) == 1\n\n                            # now check that we have all folds\n                            selected_row = txt[np.argwhere(trainer_plan_d_config_locs)[0,0]]\n\n                            fold_results = selected_row[[i + 5 for i in folds]]\n\n                            if '' in fold_results:\n                                print('missing fold in', t, pl, d, v)\n                                f.write(\",nan\")\n                                r.append(np.nan)\n                            else:\n                                mean_dice = np.mean([float(i) for i in fold_results])\n                                f.write(f\",{mean_dice:02.4f}\")\n                                r.append(mean_dice)\n                        else:\n                            print('missing:', t, pl, d, v)\n                            f.write(\",nan\")\n                            r.append(np.nan)\n                f.write(f\",{np.mean(r):02.4f}\\n\")\n\n\nif __name__ == '__main__':\n    use_these_trainers = {\n        'nnUNetTrainer': ('nnUNetResEncUNetLPlans', ),\n    }\n    all_results_file= join(nnUNet_results, 'customDecResults.csv')\n\n    datasets = [3, 5, 8, 10, 17, 27, 55, 220, 223, 226, 219]\n    # datasets = [3, 4, 5, 8, 10, 17, 27, 55, 220, 223]\n    collect_results(use_these_trainers, datasets, all_results_file)\n\n    # folds = (0, 1, 2, 3, 4)\n    # configs = (\"2d\", )\n    # output_file = join(nnUNet_results, 'customDecResults_summary5fold.csv')\n    # summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers)\n\n    folds = (0, )\n    configs = (\"3d_fullres\", )\n    output_file = join(nnUNet_results, 'summary_fold0.csv')\n    summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers)\n\n"
  },
  {
    "path": "nnunetv2/batch_running/collect_results_custom_Decathlon_2d.py",
    "content": "from batchgenerators.utilities.file_and_folder_operations import *\n\nfrom nnunetv2.batch_running.collect_results_custom_Decathlon import collect_results, summarize\nfrom nnunetv2.paths import nnUNet_results\n\nif __name__ == '__main__':\n    use_these_trainers = {\n        'nnUNetTrainer': ('nnUNetPlans', ),\n    }\n    all_results_file = join(nnUNet_results, 'hrnet_results.csv')\n    datasets = [2, 3, 4, 17, 20, 24, 27, 38, 55, 64, 82]\n    collect_results(use_these_trainers, datasets, all_results_file)\n\n    folds = (0, )\n    configs = ('2d', )\n    output_file = join(nnUNet_results, 'hrnet_results_summary_fold0.csv')\n    summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers)\n\n"
  },
  {
    "path": "nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py",
    "content": "from copy import deepcopy\nimport numpy as np\n\n\ndef merge(dict1, dict2):\n    keys = np.unique(list(dict1.keys()) + list(dict2.keys()))\n    keys = np.unique(keys)\n    res = {}\n    for k in keys:\n        all_configs = []\n        if dict1.get(k) is not None:\n            all_configs += list(dict1[k])\n        if dict2.get(k) is not None:\n            all_configs += list(dict2[k])\n        if len(all_configs) > 0:\n            res[k] = tuple(np.unique(all_configs))\n    return res\n\n\nif __name__ == \"__main__\":\n    # after the Nature Methods paper we switch our evaluation to a different (more stable/high quality) set of\n    # datasets for evaluation and future development\n    configurations_all = {\n        2: (\"3d_fullres\", \"2d\"),\n        3: (\"2d\", \"3d_lowres\", \"3d_fullres\", \"3d_cascade_fullres\"),\n        4: (\"2d\", \"3d_fullres\"),\n        5: (\"2d\", \"3d_fullres\"),\n        8: (\"2d\", \"3d_lowres\", \"3d_fullres\", \"3d_cascade_fullres\"),\n        10: (\"2d\", \"3d_lowres\", \"3d_fullres\", \"3d_cascade_fullres\"),\n        17: (\"2d\", \"3d_lowres\", \"3d_fullres\", \"3d_cascade_fullres\"),\n        24: (\"2d\", \"3d_fullres\"),\n        27: (\"2d\", \"3d_fullres\"),\n        38: (\"2d\", \"3d_fullres\"),\n        55: (\"2d\", \"3d_lowres\", \"3d_fullres\", \"3d_cascade_fullres\"),\n        137: (\"2d\", \"3d_fullres\"),\n        220: (\"2d\", \"3d_lowres\", \"3d_fullres\", \"3d_cascade_fullres\"),\n        # 221: (\"2d\", \"3d_lowres\", \"3d_fullres\", \"3d_cascade_fullres\"),\n        223: (\"2d\", \"3d_lowres\", \"3d_fullres\", \"3d_cascade_fullres\"),\n        219: (\"2d\", \"3d_fullres\"),\n        226: (\"2d\", \"3d_fullres\"),\n    }\n\n    configurations_3d_fr_only = {\n        i: (\"3d_fullres\", ) for i in configurations_all if \"3d_fullres\" in configurations_all[i]\n    }\n\n    configurations_3d_c_only = {\n        i: (\"3d_cascade_fullres\", ) for i in configurations_all if \"3d_cascade_fullres\" in 
configurations_all[i]\n    }\n\n    configurations_3d_lr_only = {\n        i: (\"3d_lowres\", ) for i in configurations_all if \"3d_lowres\" in configurations_all[i]\n    }\n\n    configurations_2d_only = {\n        i: (\"2d\", ) for i in configurations_all if \"2d\" in configurations_all[i]\n    }\n\n    num_gpus = 1\n    exclude_hosts = \"-R \\\"select[hname!='e230-dgx2-2']\\\" -R \\\"select[hname!='e230-dgx2-1']\\\"\"\n    resources = \"\"\n    gpu_requirements = f\"-gpu num={num_gpus}:j_exclusive=yes:gmem=23G\"#gmodel=NVIDIAA100_PCIE_40GB\"\n    queue = \"-q gpu-pro\"\n    preamble = \"\\\". /home/isensee/env_loading_scripts/continuous_performance_monitoring/load_env_torch280.sh && \" # -L /bin/bash\n    train_command = 'nnUNet_results=/dkfz/cluster/gpu/checkpoints/OE0441/isensee/results_nnUNet_master nnUNetv2_train'\n\n\n    folds = (0, )\n    # use_this = configurations_2d_only\n    use_this = configurations_3d_fr_only\n    # use_this = merge(use_this, configurations_3d_c_only)\n\n    datasets = [3, 5, 8, 10, 17, 27, 55, 219, 220, 223, 226]\n    use_this = {i: use_this[i] for i in datasets}\n\n    use_these_modules = {\n        # 'nnUNetTrainer_newSpatialAug': ('nnUNetPlans',),\n        # 'nnUNetTrainerBN': ('nnUNetPlans',),\n        # 'nnUNetTrainer_newSpatialAug_withElDef_noPref': ('nnUNetPlans',),\n        # 'nnUNetTrainer': ('nnUNetConvNextEncUNetPlans_smallks_and_shallow',),\n        # 'nnUNetTrainer_convnextenc_regularconvblock': ('nnUNetPlans_convnext', 'nnUNetConvNextEncUNetPlans_smallks_and_shallow'),\n        # 'nnUNetTrainer_newSpatialAug_withElDef2': ('nnUNetPlans',),\n        # 'nnUNetTrainer_newSpatialAug_withElDef3': ('nnUNetPlans',),\n        # 'nnUNetTrainer_newSpatialAug_noPref': ('nnUNetPlans',),\n        # 'nnUNetTrainerDiceCELoss_noSmooth': ('nnUNetPlans',),\n        # 'nnUNetTrainer_DASegOrd0': ('nnUNetPlans',),\n        # 'nnUNetTrainerAdamW_WDe2': ('nnUNetPlans',),\n        # 'nnUNetTrainerUMambaBot': ('nnUNetPlans',),\n        # 
'nnUNetTrainerUMambaEnc': ('nnUNetPlans',),\n        # 'nnUNetTrainer_fasterDA': ('nnUNetPlans', 'nnUNetResEncUNetLPlans'),\n        # 'nnUNetTrainer_noDummy2DDA': ('nnUNetResEncUNetMPlans', ),\n        'nnUNetTrainer': ('nnUNetResEncUNetMPlans', ),\n        'nnUNetTrainerDA5': ('nnUNetResEncUNetMPlans', ),\n        # 'nnUNetTrainer_probabilisticOversampling_033': ('nnUNetResEncUNetMPlans', ),\n        # 'nnUNetTrainer_probabilisticOversampling_010': ('nnUNetResEncUNetMPlans',),\n        # BN\n    }\n\n    additional_arguments = f' -num_gpus {num_gpus} --disable_checkpointing'  # ''\n\n    output_file = \"/home/isensee/deleteme.txt\"\n    with open(output_file, 'w') as f:\n        for tr in use_these_modules.keys():\n            for p in use_these_modules[tr]:\n                for dataset in use_this.keys():\n                    for config in use_this[dataset]:\n                        for fl in folds:\n                            command = f'bsub {exclude_hosts} {resources} {queue} {gpu_requirements} {preamble} {train_command} {dataset} {config} {fl} -tr {tr} -p {p}'\n                            if additional_arguments is not None and len(additional_arguments) > 0:\n                                command += f' {additional_arguments}'\n                            f.write(f'{command}\\\"\\n')\n\n"
  },
  {
    "path": "nnunetv2/batch_running/jobs.sh",
    "content": "# lsf22-gpu01\nscreen -dm bash -c \". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=3 nnUNetv2_train 3 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nscreen -dm bash -c \". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=4 nnUNetv2_train 4 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nscreen -dm bash -c \". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=5 nnUNetv2_train 5 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nscreen -dm bash -c \". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=6 nnUNetv2_train 8 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nscreen -dm bash -c \". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=7 nnUNetv2_train 10 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing\"\n\n# lsf22-gpu03\nscreen -dm bash -c \". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=0 nnUNetv2_train 17 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nscreen -dm bash -c \". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=1 nnUNetv2_train 27 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nscreen -dm bash -c \". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=2 nnUNetv2_train 55 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nscreen -dm bash -c \". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=6 nnUNetv2_train 220 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing\"\n\n# lsf22-gpu05\nscreen -dm bash -c \". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=0 nnUNetv2_train 223 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nscreen -dm bash -c \". 
~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=1 nnUNetv2_train 3 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nscreen -dm bash -c \". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=2 nnUNetv2_train 4 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nscreen -dm bash -c \". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=3 nnUNetv2_train 5 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nscreen -dm bash -c \". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=4 nnUNetv2_train 8 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nscreen -dm bash -c \". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=6 nnUNetv2_train 10 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nscreen -dm bash -c \". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=7 nnUNetv2_train 17 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing\"\n\n# lsf22-gpu06\nscreen -dm bash -c \". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=0 nnUNetv2_train 27 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nscreen -dm bash -c \". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=1 nnUNetv2_train 55 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nscreen -dm bash -c \". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=2 nnUNetv2_train 220 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nscreen -dm bash -c \". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=3 nnUNetv2_train 223 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nscreen -dm bash -c \". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=4 nnUNetv2_train 3 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nscreen -dm bash -c \". 
~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=5 nnUNetv2_train 4 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nscreen -dm bash -c \". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=6 nnUNetv2_train 5 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nscreen -dm bash -c \". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=7 nnUNetv2_train 8 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing\"\n\n# lsf22-gpu07\nscreen -dm bash -c \". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=0 nnUNetv2_train 10 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nscreen -dm bash -c \". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=1 nnUNetv2_train 17 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nscreen -dm bash -c \". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=2 nnUNetv2_train 27 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nscreen -dm bash -c \". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=3 nnUNetv2_train 55 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nscreen -dm bash -c \". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=4 nnUNetv2_train 220 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nscreen -dm bash -c \". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=5 nnUNetv2_train 223 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nscreen -dm bash -c \". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=6 nnUNetv2_train 3 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nscreen -dm bash -c \". 
~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=7 nnUNetv2_train 4 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing\"\n\n\n# launched as jobs\nbsub -R \"select[hname!='e230-dgx2-2']\" -R \"select[hname!='e230-dgx2-1']\"  -q gpu -gpu num=1:j_exclusive=yes:gmem=1G \". ~/load_env_torch224_balintsfix.sh && nnUNetv2_train 5 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nbsub -R \"select[hname!='e230-dgx2-2']\" -R \"select[hname!='e230-dgx2-1']\"  -q gpu -gpu num=1:j_exclusive=yes:gmem=1G \". ~/load_env_torch224_balintsfix.sh && nnUNetv2_train 8 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nbsub -R \"select[hname!='e230-dgx2-2']\" -R \"select[hname!='e230-dgx2-1']\"  -q gpu -gpu num=1:j_exclusive=yes:gmem=1G \". ~/load_env_torch224_balintsfix.sh && nnUNetv2_train 10 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nbsub -R \"select[hname!='e230-dgx2-2']\" -R \"select[hname!='e230-dgx2-1']\"  -q gpu -gpu num=1:j_exclusive=yes:gmem=1G \". ~/load_env_torch224_balintsfix.sh && nnUNetv2_train 17 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nbsub -R \"select[hname!='e230-dgx2-2']\" -R \"select[hname!='e230-dgx2-1']\"  -q gpu -gpu num=1:j_exclusive=yes:gmem=1G \". ~/load_env_torch224_balintsfix.sh && nnUNetv2_train 27 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nbsub -R \"select[hname!='e230-dgx2-2']\" -R \"select[hname!='e230-dgx2-1']\"  -q gpu -gpu num=1:j_exclusive=yes:gmem=1G \". ~/load_env_torch224_balintsfix.sh && nnUNetv2_train 55 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nbsub -R \"select[hname!='e230-dgx2-2']\" -R \"select[hname!='e230-dgx2-1']\"  -q gpu -gpu num=1:j_exclusive=yes:gmem=1G \". 
~/load_env_torch224_balintsfix.sh && nnUNetv2_train 220 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing\"\nbsub -R \"select[hname!='e230-dgx2-2']\" -R \"select[hname!='e230-dgx2-1']\"  -q gpu -gpu num=1:j_exclusive=yes:gmem=1G \". ~/load_env_torch224_balintsfix.sh && nnUNetv2_train 223 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing\"\n\n"
  },
  {
    "path": "nnunetv2/batch_running/release_trainings/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/batch_running/release_trainings/nnunetv2_v1/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/batch_running/release_trainings/nnunetv2_v1/collect_results.py",
    "content": "from typing import Tuple\n\nimport numpy as np\nfrom batchgenerators.utilities.file_and_folder_operations import *\n\nfrom nnunetv2.evaluation.evaluate_predictions import load_summary_json\nfrom nnunetv2.paths import nnUNet_results\nfrom nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name, convert_dataset_name_to_id\nfrom nnunetv2.utilities.file_path_utilities import get_output_folder\n\n\ndef collect_results(trainers: dict, datasets: List, output_file: str,\n                    configurations=(\"2d\", \"3d_fullres\", \"3d_lowres\", \"3d_cascade_fullres\"),\n                    folds=tuple(np.arange(5))):\n    results_dirs = (nnUNet_results,)\n    datasets_names = [maybe_convert_to_dataset_name(i) for i in datasets]\n    with open(output_file, 'w') as f:\n        for i, d in zip(datasets, datasets_names):\n            for c in configurations:\n                for module in trainers.keys():\n                    for plans in trainers[module]:\n                        for r in results_dirs:\n                            expected_output_folder = get_output_folder(d, module, plans, c)\n                            if isdir(expected_output_folder):\n                                results_folds = []\n                                f.write(f\"{d},{c},{module},{plans},{r}\")\n                                for fl in folds:\n                                    expected_output_folder_fold = get_output_folder(d, module, plans, c, fl)\n                                    expected_summary_file = join(expected_output_folder_fold, \"validation\",\n                                                                 \"summary.json\")\n                                    if not isfile(expected_summary_file):\n                                        print('expected output file not found:', expected_summary_file)\n                                        f.write(\",\")\n                                        results_folds.append(np.nan)\n    
                                else:\n                                        foreground_mean = load_summary_json(expected_summary_file)['foreground_mean'][\n                                            'Dice']\n                                        results_folds.append(foreground_mean)\n                                        f.write(f\",{foreground_mean:02.4f}\")\n                                f.write(f\",{np.nanmean(results_folds):02.4f}\\n\")\n\n\ndef summarize(input_file, output_file, folds: Tuple[int, ...], configs: Tuple[str, ...], datasets, trainers):\n    txt = np.loadtxt(input_file, dtype=str, delimiter=',')\n    num_folds = txt.shape[1] - 6\n    valid_configs = {}\n    for d in datasets:\n        if isinstance(d, int):\n            d = maybe_convert_to_dataset_name(d)\n        configs_in_txt = np.unique(txt[:, 1][txt[:, 0] == d])\n        valid_configs[d] = [i for i in configs_in_txt if i in configs]\n    assert max(folds) < num_folds\n\n    with open(output_file, 'w') as f:\n        f.write(\"name\")\n        for d in valid_configs.keys():\n            for c in valid_configs[d]:\n                f.write(\",%d_%s\" % (convert_dataset_name_to_id(d), c[:4]))\n        f.write(',mean\\n')\n        valid_entries = txt[:, 4] == nnUNet_results\n        for t in trainers.keys():\n            trainer_locs = valid_entries & (txt[:, 2] == t)\n            for pl in trainers[t]:\n                f.write(f\"{t}__{pl}\")\n                trainer_plan_locs = trainer_locs & (txt[:, 3] == pl)\n                r = []\n                for d in valid_configs.keys():\n                    trainer_plan_d_locs = trainer_plan_locs & (txt[:, 0] == d)\n                    for v in valid_configs[d]:\n                        trainer_plan_d_config_locs = trainer_plan_d_locs & (txt[:, 1] == v)\n                        if np.any(trainer_plan_d_config_locs):\n                            # we cannot have more than one row\n                            assert 
np.sum(trainer_plan_d_config_locs) == 1\n\n                            # now check that we have all folds\n                            selected_row = txt[np.argwhere(trainer_plan_d_config_locs)[0,0]]\n\n                            fold_results = selected_row[[i + 5 for i in folds]]\n\n                            if '' in fold_results:\n                                print('missing fold in', t, pl, d, v)\n                                f.write(\",nan\")\n                                r.append(np.nan)\n                            else:\n                                mean_dice = np.mean([float(i) for i in fold_results])\n                                f.write(f\",{mean_dice:02.4f}\")\n                                r.append(mean_dice)\n                        else:\n                            print('missing:', t, pl, d, v)\n                            f.write(\",nan\")\n                            r.append(np.nan)\n                f.write(f\",{np.mean(r):02.4f}\\n\")\n\n\nif __name__ == '__main__':\n    use_these_trainers = {\n        'nnUNetTrainer': ('nnUNetPlans',),\n        'nnUNetTrainer_v1loss': ('nnUNetPlans',),\n     }\n    all_results_file = join(nnUNet_results, 'customDecResults.csv')\n    datasets = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 17, 20, 24, 27, 35, 38, 48, 55, 64, 82]\n    collect_results(use_these_trainers, datasets, all_results_file)\n\n    folds = (0, 1, 2, 3, 4)\n    configs = (\"3d_fullres\", \"3d_lowres\")\n    output_file = join(nnUNet_results, 'customDecResults_summary5fold.csv')\n    summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers)\n\n    folds = (0, )\n    configs = (\"3d_fullres\", \"3d_lowres\")\n    output_file = join(nnUNet_results, 'customDecResults_summaryfold0.csv')\n    summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers)\n\n"
  },
  {
    "path": "nnunetv2/batch_running/release_trainings/nnunetv2_v1/generate_lsf_commands.py",
    "content": "from copy import deepcopy\nimport numpy as np\n\n\ndef merge(dict1, dict2):\n    keys = np.unique(list(dict1.keys()) + list(dict2.keys()))\n    keys = np.unique(keys)\n    res = {}\n    for k in keys:\n        all_configs = []\n        if dict1.get(k) is not None:\n            all_configs += list(dict1[k])\n        if dict2.get(k) is not None:\n            all_configs += list(dict2[k])\n        if len(all_configs) > 0:\n            res[k] = tuple(np.unique(all_configs))\n    return res\n\n\nif __name__ == \"__main__\":\n    # after the Nature Methods paper we switch our evaluation to a different (more stable/high quality) set of\n    # datasets for evaluation and future development\n    configurations_all = {\n        # 1: (\"3d_fullres\", \"2d\"),\n        2: (\"3d_fullres\", \"2d\"),\n        # 3: (\"2d\", \"3d_lowres\", \"3d_fullres\", \"3d_cascade_fullres\"),\n        # 4: (\"2d\", \"3d_fullres\"),\n        5: (\"2d\", \"3d_fullres\"),\n        # 6: (\"2d\", \"3d_lowres\", \"3d_fullres\", \"3d_cascade_fullres\"),\n        # 7: (\"2d\", \"3d_lowres\", \"3d_fullres\", \"3d_cascade_fullres\"),\n        # 8: (\"2d\", \"3d_lowres\", \"3d_fullres\", \"3d_cascade_fullres\"),\n        # 9: (\"2d\", \"3d_lowres\", \"3d_fullres\", \"3d_cascade_fullres\"),\n        # 10: (\"2d\", \"3d_lowres\", \"3d_fullres\", \"3d_cascade_fullres\"),\n        # 17: (\"2d\", \"3d_lowres\", \"3d_fullres\", \"3d_cascade_fullres\"),\n        20: (\"2d\", \"3d_fullres\"),\n        24: (\"2d\", \"3d_fullres\"),\n        27: (\"2d\", \"3d_fullres\"),\n        35: (\"2d\", \"3d_fullres\"),\n        38: (\"2d\", \"3d_fullres\"),\n        # 55: (\"2d\", \"3d_lowres\", \"3d_fullres\", \"3d_cascade_fullres\"),\n        # 64: (\"2d\", \"3d_lowres\", \"3d_fullres\", \"3d_cascade_fullres\"),\n        # 82: (\"2d\", \"3d_fullres\"),\n        # 83: (\"2d\", \"3d_fullres\"),\n    }\n\n    configurations_3d_fr_only = {\n        i: (\"3d_fullres\", ) for i in configurations_all if 
\"3d_fullres\" in configurations_all[i]\n    }\n\n    configurations_3d_c_only = {\n        i: (\"3d_cascade_fullres\", ) for i in configurations_all if \"3d_cascade_fullres\" in configurations_all[i]\n    }\n\n    configurations_3d_lr_only = {\n        i: (\"3d_lowres\", ) for i in configurations_all if \"3d_lowres\" in configurations_all[i]\n    }\n\n    configurations_2d_only = {\n        i: (\"2d\", ) for i in configurations_all if \"2d\" in configurations_all[i]\n    }\n\n    num_gpus = 1\n    exclude_hosts = \"-R \\\"select[hname!='e230-dgx2-2']\\\" -R \\\"select[hname!='e230-dgx2-1']\\\"\"\n    resources = \"-R \\\"tensorcore\\\"\"\n    gpu_requirements = f\"-gpu num={num_gpus}:j_exclusive=yes:gmem=1G\"\n    queue = \"-q gpu-lowprio\"\n    preamble = \"-L /bin/bash \\\"source ~/load_env_cluster4.sh && \"\n    train_command = 'nnUNet_keep_files_open=True nnUNet_results=/dkfz/cluster/gpu/data/OE0441/isensee/nnUNet_results_remake_release_normfix nnUNetv2_train'\n\n    folds = (0, 1, 2, 3, 4)\n    # use_this = configurations_2d_only\n    # use_this = merge(configurations_3d_fr_only, configurations_3d_lr_only)\n    # use_this = merge(use_this, configurations_3d_c_only)\n    use_this = configurations_all\n\n    use_these_modules = {\n        'nnUNetTrainer': ('nnUNetPlans',),\n    }\n\n    additional_arguments = f'--disable_checkpointing -num_gpus {num_gpus}'  # ''\n\n    output_file = \"/home/isensee/deleteme.txt\"\n    with open(output_file, 'w') as f:\n        for tr in use_these_modules.keys():\n            for p in use_these_modules[tr]:\n                for dataset in use_this.keys():\n                    for config in use_this[dataset]:\n                        for fl in folds:\n                            command = f'bsub {exclude_hosts} {resources} {queue} {gpu_requirements} {preamble} {train_command} {dataset} {config} {fl} -tr {tr} -p {p}'\n                            if additional_arguments is not None and len(additional_arguments) > 0:\n               
                 command += f' {additional_arguments}'\n                            f.write(f'{command}\\\"\\n')\n\n"
  },
  {
    "path": "nnunetv2/configuration.py",
    "content": "import os\n\nfrom nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA\n\ndefault_num_processes = 8 if 'nnUNet_def_n_proc' not in os.environ else int(os.environ['nnUNet_def_n_proc'])\n\nANISO_THRESHOLD = 3  # determines when a sample is considered anisotropic (3 means that the spacing in the low\n# resolution axis must be 3x as large as the next largest spacing)\n\ndefault_n_proc_DA = get_allowed_n_proc_DA()\n"
  },
  {
    "path": "nnunetv2/dataset_conversion/Dataset015_018_RibFrac_RibSeg.py",
    "content": "from copy import deepcopy\n\nimport numpy as np\nfrom batchgenerators.utilities.file_and_folder_operations import *\nimport shutil\nfrom nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json\nfrom nnunetv2.paths import nnUNet_raw\nimport SimpleITK as sitk\n\nif __name__ == '__main__':\n    \"\"\"\n    Download RibFrac dataset. Links are at https://ribfrac.grand-challenge.org/\n    Download everything. Part1, 2, validation and test\n    Extract EVERYTHING into one folder so that all images and labels are in there. Don't worry they all have unique \n    file names.\n    \n    For RibSeg also download the dataset from https://github.com/M3DV/RibSeg \n    (https://drive.google.com/file/d/1ZZGGrhd0y1fLyOZGo_Y-wlVUP4lkHVgm/view?usp=sharing) and extract in to that same \n    folder (seg only, files end with -rib-seg.nii.gz)\n    \"\"\"\n    # extracted traiing.zip file is here\n    base = '/home/isensee/Downloads/RibFrac_all'\n\n    files = nifti_files(base, join=False)\n    identifiers = np.unique([i.split('-')[0] for i in files])\n\n    # RibFrac\n    target_dataset_id = 15\n    target_dataset_name = f'Dataset{target_dataset_id:03.0f}_RibFrac'\n\n    maybe_mkdir_p(join(nnUNet_raw, target_dataset_name))\n    imagesTr = join(nnUNet_raw, target_dataset_name, 'imagesTr')\n    imagesTs = join(nnUNet_raw, target_dataset_name, 'imagesTs')\n    labelsTr = join(nnUNet_raw, target_dataset_name, 'labelsTr')\n    maybe_mkdir_p(imagesTr)\n    maybe_mkdir_p(imagesTs)\n    maybe_mkdir_p(labelsTr)\n\n    n_tr = 0\n    for c in identifiers:\n        print(c)\n        img_file = join(base, c + '-image.nii.gz')\n        seg_file = join(base, c + '-label.nii.gz')\n        if not isfile(seg_file):\n            # test case\n            shutil.copy(img_file, join(imagesTs, c + '_0000.nii.gz'))\n            continue\n        n_tr += 1\n        shutil.copy(img_file, join(imagesTr, c + '_0000.nii.gz'))\n\n        # we must open seg and map -1 to 5\n       
 seg_itk = sitk.ReadImage(seg_file)\n        seg_npy = sitk.GetArrayFromImage(seg_itk)\n        seg_npy[seg_npy == -1] = 5\n        seg_itk_out = sitk.GetImageFromArray(seg_npy.astype(np.uint8))\n        seg_itk_out.SetSpacing(seg_itk.GetSpacing())\n        seg_itk_out.SetDirection(seg_itk.GetDirection())\n        seg_itk_out.SetOrigin(seg_itk.GetOrigin())\n        sitk.WriteImage(seg_itk_out, join(labelsTr, c + '.nii.gz'))\n\n    # - 0: it is background\n    # - 1: it is a displaced rib fracture\n    # - 2: it is a non-displaced rib fracture\n    # - 3: it is a buckle rib fracture\n    # - 4: it is a segmental rib fracture\n    # - -1: it is a rib fracture,  but we could not define its type due to\n    #   ambiguity, diagnosis difficulty, etc. Ignore it in the\n    #   classification task.\n\n    generate_dataset_json(\n        join(nnUNet_raw, target_dataset_name),\n        channel_names={0: 'CT'},\n        labels = {\n            'background': 0,\n            'fracture': (1, 2, 3, 4, 5),\n            'displaced rib fracture': 1,\n            'non-displaced rib fracture': 2,\n            'buckle rib fracture': 3,\n            'segmental rib fracture': 4,\n        },\n        num_training_cases=n_tr,\n        file_ending='.nii.gz',\n        regions_class_order=(5, 1, 2, 3, 4),\n        dataset_name=target_dataset_name,\n        reference='https://ribfrac.grand-challenge.org/'\n    )\n\n    # RibSeg\n    # overall I am not happy with the GT quality here. 
But eh what can I do\n\n    target_dataset_name_ribfrac = deepcopy(target_dataset_name)\n    target_dataset_id = 18\n    target_dataset_name = f'Dataset{target_dataset_id:03.0f}_RibSeg'\n\n    maybe_mkdir_p(join(nnUNet_raw, target_dataset_name))\n    imagesTr = join(nnUNet_raw, target_dataset_name, 'imagesTr')\n    labelsTr = join(nnUNet_raw, target_dataset_name, 'labelsTr')\n    maybe_mkdir_p(imagesTr)\n    maybe_mkdir_p(labelsTr)\n\n    # the authors have a Google Sheet where they highlight problems with their dataset:\n    # https://docs.google.com/spreadsheets/d/1lz9liWPy8yHybKCdO3BCA9K76QH8a54XduiZS_9fK70/edit?gid=1416415020#gid=1416415020\n    # we exclude the cases marked in red. They have unannotated ribs\n    skip_identifiers = [\n        'RibFrac452',\n        'RibFrac485',\n        'RibFrac490',\n        'RibFrac471',\n        'RibFrac462',\n        'RibFrac487',\n    ]\n\n    n_tr = 0\n    dataset = {}\n    for c in identifiers:\n        if c in skip_identifiers:\n            continue\n        print(c)\n        tr_file = join('$nnUNet_raw', target_dataset_name_ribfrac, 'imagesTr', c + '_0000.nii.gz')\n        ts_file = join('$nnUNet_raw', target_dataset_name_ribfrac, 'imagesTs', c + '_0000.nii.gz')\n        if isfile(os.path.expandvars(tr_file)):\n            img_file = tr_file\n        elif isfile(os.path.expandvars(ts_file)):\n            img_file = ts_file\n        else:\n            raise RuntimeError(f'Missing image file for identifier {c}')\n        seg_file = join(base, c + '-rib-seg.nii.gz')\n        n_tr += 1\n        shutil.copy(seg_file, join(labelsTr, c + '.nii.gz'))\n        dataset[c] = {\n            'images': [img_file],\n            'label': join('labelsTr', c + '.nii.gz')\n        }\n\n    generate_dataset_json(\n        join(nnUNet_raw, target_dataset_name),\n        channel_names={0: 'CT'},\n        labels = {\n            'background': 0,\n            **{'rib%02.0d' % i: i for i in range(1, 25)}\n        },\n        
num_training_cases=n_tr,\n        file_ending='.nii.gz',\n        dataset_name=target_dataset_name,\n        reference='https://github.com/M3DV/RibSeg, https://ribfrac.grand-challenge.org/',\n        dataset=dataset\n    )\n"
  },
  {
    "path": "nnunetv2/dataset_conversion/Dataset021_CTAAorta.py",
    "content": "from batchgenerators.utilities.file_and_folder_operations import *\nimport shutil\nfrom nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json\nfrom nnunetv2.paths import nnUNet_raw\nimport SimpleITK as sitk\n\n\nif __name__ == '__main__':\n    \"\"\"\n    \n    \"\"\"\n    # extracted traiing.zip file is here\n    base = '/home/isensee/Downloads/'\n    target_dataset_id = 21\n    target_dataset_name = f'Dataset{target_dataset_id:03.0f}_CTAAorta'\n\n    maybe_mkdir_p(join(nnUNet_raw, target_dataset_name))\n    imagesTr = join(nnUNet_raw, target_dataset_name, 'imagesTr')\n    labelsTr = join(nnUNet_raw, target_dataset_name, 'labelsTr')\n    maybe_mkdir_p(imagesTr)\n    maybe_mkdir_p(labelsTr)\n\n    cases = subfiles(join(base, 'images'), join=False, prefix='subject')\n    for case in cases:\n        outname = case.replace('_CTA.mha', '')\n        im = sitk.ReadImage(join(base, 'images', case))\n        sitk.WriteImage(im, join(imagesTr, outname + '_0000.nii.gz'))\n\n        seg = sitk.ReadImage(join(base, 'masks', case.replace('_CTA.mha', '_label.mha')))\n        sitk.WriteImage(seg, join(labelsTr, outname + '.nii.gz'))\n\n    labels = {\n            \"background\": 0,\n            \"Zone_0\": 1,\n            \"Innominate\": 2,\n            \"Zone_1\": 3,\n            \"Left_Common_Carotid\": 4,\n            \"Zone_2\": 5,\n            \"Left_Subclavian_Artery\": 6,\n            \"Zone_3\": 7,\n            \"Zone_4\": 8,\n            \"Zone_5\": 9,\n            \"Zone_6\": 10,\n            \"Celiac_Artery\": 11,\n            \"Zone_7\": 12,\n            \"SMA\": 13,\n            \"Zone_8\": 14,\n            \"Right_Renal_Artery\": 15,\n            \"Left_Renal_Artery\": 16,\n            \"Zone_9\": 17,\n            \"Zone_10_R_(Right_Common_Iliac_Artery)\": 18,\n            \"Zone_10_L_(Left_Common_Iliac_Artery)\": 19,\n            \"Right_Internal_Iliac_Artery_Dice_Score\": 20,\n            
\"Left_Internal_Iliac_Artery_Dice_Score\": 21,\n            \"Zone_11_R_(Right_External_Iliac_Artery)\": 22,\n            \"Zone_11_L_(Left_External_Iliac_Artery)\": 23\n        }\n\n\n    generate_dataset_json(\n        join(nnUNet_raw, target_dataset_name),\n        {0: 'CTA'},\n        labels,\n        len(cases),\n        '.nii.gz',\n        None,\n        target_dataset_name,\n        overwrite_image_reader_writer='NibabelIOWithReorient',\n        reference='https://aortaseg24.grand-challenge.org/',\n        license='see ref'\n    )"
  },
  {
    "path": "nnunetv2/dataset_conversion/Dataset023_AbdomenAtlas1_1Mini.py",
    "content": "from batchgenerators.utilities.file_and_folder_operations import *\nimport shutil\nfrom nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json\nfrom nnunetv2.paths import nnUNet_raw\n\nif __name__ == '__main__':\n    \"\"\"\n    Download the dataset from huggingface:\n    https://huggingface.co/datasets/AbdomenAtlas/_AbdomenAtlas1.1Mini#3--download-the-dataset\n    \n    IMPORTANT\n    cases 5196-9262 currently do not have images, just the segmentation. This seems to be a mistake \n    \"\"\"\n    base = '/home/isensee/Downloads/AbdomenAtlas/uncompressed'\n    target_dataset_id = 23\n    target_dataset_name = f'Dataset{target_dataset_id:03.0f}_AbdomenAtlas1.1Mini'\n\n    cases = subdirs(base, join=False, prefix='BDMAP')\n\n    maybe_mkdir_p(join(nnUNet_raw, target_dataset_name))\n    imagesTr = join(nnUNet_raw, target_dataset_name, 'imagesTr')\n    labelsTr = join(nnUNet_raw, target_dataset_name, 'labelsTr')\n    maybe_mkdir_p(imagesTr)\n    maybe_mkdir_p(labelsTr)\n\n    for case in cases:\n        if not isfile(join(base, case, 'ct.nii.gz')):\n            print(f'Skipping case {case} due to missing image')\n            continue\n        shutil.copy(join(base, case, 'ct.nii.gz'), join(imagesTr, case + '_0000.nii.gz'))\n        shutil.copy(join(base, case, 'combined_labels.nii.gz'), join(labelsTr, case + '.nii.gz'))\n\n    class_map = {1: 'aorta', 2: 'gall_bladder', 3: 'kidney_left', 4: 'kidney_right', 5: 'liver',\n                 6: 'pancreas', 7: 'postcava', 8: 'spleen', 9: 'stomach', 10: 'adrenal_gland_left',\n                 11: 'adrenal_gland_right', 12: 'bladder', 13: 'celiac_trunk', 14: 'colon', 15: 'duodenum',\n                 16: 'esophagus', 17: 'femur_left', 18: 'femur_right', 19: 'hepatic_vessel', 20: 'intestine',\n                 21: 'lung_left', 22: 'lung_right', 23: 'portal_vein_and_splenic_vein',\n                 24: 'prostate', 25: 'rectum'}\n    labels = {\n        j: i for i, j in class_map.items()\n   
 }\n    labels['background'] = 0\n\n    generate_dataset_json(\n        join(nnUNet_raw, target_dataset_name),\n        {0: 'CT'},\n        labels,\n        len(cases),\n        '.nii.gz',\n        None,\n        target_dataset_name,\n        overwrite_image_reader_writer='NibabelIOWithReorient',\n        reference='https://huggingface.co/datasets/AbdomenAtlas/_AbdomenAtlas1.1Mini',\n        license='Creative Commons Attribution Non Commercial Share Alike 4.0; see reference'\n    )"
  },
  {
    "path": "nnunetv2/dataset_conversion/Dataset027_ACDC.py",
    "content": "import os\nimport shutil\nfrom pathlib import Path\nfrom typing import List\n\nfrom batchgenerators.utilities.file_and_folder_operations import nifti_files, join, maybe_mkdir_p, save_json\nfrom nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json\nfrom nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed\nimport numpy as np\n\n\ndef make_out_dirs(dataset_id: int, task_name=\"ACDC\"):\n    dataset_name = f\"Dataset{dataset_id:03d}_{task_name}\"\n\n    out_dir = Path(nnUNet_raw.replace('\"', \"\")) / dataset_name\n    out_train_dir = out_dir / \"imagesTr\"\n    out_labels_dir = out_dir / \"labelsTr\"\n    out_test_dir = out_dir / \"imagesTs\"\n\n    os.makedirs(out_dir, exist_ok=True)\n    os.makedirs(out_train_dir, exist_ok=True)\n    os.makedirs(out_labels_dir, exist_ok=True)\n    os.makedirs(out_test_dir, exist_ok=True)\n\n    return out_dir, out_train_dir, out_labels_dir, out_test_dir\n\n\ndef create_ACDC_split(labelsTr_folder: str, seed: int = 1234) -> List[dict[str, List]]:\n    # labelsTr_folder = '/home/isensee/drives/gpu_data_root/OE0441/isensee/nnUNet_raw/nnUNet_raw_remake/Dataset027_ACDC/labelsTr'\n    nii_files = nifti_files(labelsTr_folder, join=False)\n    patients = np.unique([i[:len('patient000')] for i in nii_files])\n    rs = np.random.RandomState(seed)\n    rs.shuffle(patients)\n    splits = []\n    for fold in range(5):\n        val_patients = patients[fold::5]\n        train_patients = [i for i in patients if i not in val_patients]\n        val_cases = [i[:-7] for i in nii_files for j in val_patients if i.startswith(j)]\n        train_cases = [i[:-7] for i in nii_files for j in train_patients if i.startswith(j)]\n        splits.append({'train': train_cases, 'val': val_cases})\n    return splits\n\n\ndef copy_files(src_data_folder: Path, train_dir: Path, labels_dir: Path, test_dir: Path):\n    \"\"\"Copy files from the ACDC dataset to the nnUNet dataset folder. 
Returns the number of training cases.\"\"\"\n    patients_train = sorted([f for f in (src_data_folder / \"training\").iterdir() if f.is_dir()])\n    patients_test = sorted([f for f in (src_data_folder / \"testing\").iterdir() if f.is_dir()])\n\n    num_training_cases = 0\n    # Copy training files and corresponding labels.\n    for patient_dir in patients_train:\n        for file in patient_dir.iterdir():\n            if file.suffix == \".gz\" and \"_gt\" not in file.name and \"_4d\" not in file.name:\n                # The stem is 'patient.nii', and the suffix is '.gz'.\n                # We split the stem and append _0000 to the patient part.\n                shutil.copy(file, train_dir / f\"{file.stem.split('.')[0]}_0000.nii.gz\")\n                num_training_cases += 1\n            elif file.suffix == \".gz\" and \"_gt\" in file.name:\n                shutil.copy(file, labels_dir / file.name.replace(\"_gt\", \"\"))\n\n    # Copy test files.\n    for patient_dir in patients_test:\n        for file in patient_dir.iterdir():\n            if file.suffix == \".gz\" and \"_gt\" not in file.name and \"_4d\" not in file.name:\n                shutil.copy(file, test_dir / f\"{file.stem.split('.')[0]}_0000.nii.gz\")\n\n    return num_training_cases\n\n\ndef convert_acdc(src_data_folder: str, dataset_id=27):\n    out_dir, train_dir, labels_dir, test_dir = make_out_dirs(dataset_id=dataset_id)\n    num_training_cases = copy_files(Path(src_data_folder), train_dir, labels_dir, test_dir)\n\n    generate_dataset_json(\n        str(out_dir),\n        channel_names={\n            0: \"cineMRI\",\n        },\n        labels={\n            \"background\": 0,\n            \"RV\": 1,\n            \"MLV\": 2,\n            \"LVC\": 3,\n        },\n        file_ending=\".nii.gz\",\n        num_training_cases=num_training_cases,\n    )\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"-i\",\n        
\"--input_folder\",\n        type=str,\n        help=\"The downloaded ACDC dataset dir. Should contain extracted 'training' and 'testing' folders.\",\n    )\n    parser.add_argument(\n        \"-d\", \"--dataset_id\", required=False, type=int, default=27, help=\"nnU-Net Dataset ID, default: 27\"\n    )\n    args = parser.parse_args()\n    print(\"Converting...\")\n    convert_acdc(args.input_folder, args.dataset_id)\n\n    dataset_name = f\"Dataset{args.dataset_id:03d}_{'ACDC'}\"\n    labelsTr = join(nnUNet_raw, dataset_name, 'labelsTr')\n    preprocessed_folder = join(nnUNet_preprocessed, dataset_name)\n    maybe_mkdir_p(preprocessed_folder)\n    split = create_ACDC_split(labelsTr)\n    save_json(split, join(preprocessed_folder, 'splits_final.json'), sort_keys=False)\n\n    print(\"Done!\")\n"
  },
  {
    "path": "nnunetv2/dataset_conversion/Dataset042_BraTS18.py",
    "content": "import multiprocessing\nimport shutil\n\nimport SimpleITK as sitk\nimport numpy as np\nfrom tqdm import tqdm\nfrom batchgenerators.utilities.file_and_folder_operations import *\nfrom nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json\nfrom nnunetv2.paths import nnUNet_raw\n\n\ndef copy_BraTS_segmentation_and_convert_labels_to_nnUNet(in_file: str, out_file: str) -> None:\n    # use this for segmentation only!!!\n    # nnUNet wants the labels to be continuous. BraTS is 0, 1, 2, 4 -> we make that into 0, 1, 2, 3\n    img = sitk.ReadImage(in_file)\n    img_npy = sitk.GetArrayFromImage(img)\n\n    uniques = np.unique(img_npy)\n    for u in uniques:\n        if u not in [0, 1, 2, 4]:\n            raise RuntimeError('unexpected label')\n\n    seg_new = np.zeros_like(img_npy)\n    seg_new[img_npy == 4] = 3\n    seg_new[img_npy == 2] = 1\n    seg_new[img_npy == 1] = 2\n    img_corr = sitk.GetImageFromArray(seg_new)\n    img_corr.CopyInformation(img)\n    sitk.WriteImage(img_corr, out_file)\n\n\ndef convert_labels_back_to_BraTS(seg: np.ndarray):\n    new_seg = np.zeros_like(seg)\n    new_seg[seg == 1] = 2\n    new_seg[seg == 3] = 4\n    new_seg[seg == 2] = 1\n    return new_seg\n\n\ndef load_convert_labels_back_to_BraTS(filename, input_folder, output_folder):\n    a = sitk.ReadImage(join(input_folder, filename))\n    b = sitk.GetArrayFromImage(a)\n    c = convert_labels_back_to_BraTS(b)\n    d = sitk.GetImageFromArray(c)\n    d.CopyInformation(a)\n    sitk.WriteImage(d, join(output_folder, filename))\n\n\ndef convert_folder_with_preds_back_to_BraTS_labeling_convention(input_folder: str, output_folder: str,\n                                                                num_processes: int = 12):\n    \"\"\"\n    reads all prediction files (nifti) in the input folder, converts the labels back to BraTS convention and saves the\n    \"\"\"\n    maybe_mkdir_p(output_folder)\n    nii = subfiles(input_folder, suffix='.nii.gz', 
join=False)\n    with multiprocessing.get_context(\"spawn\").Pool(num_processes) as p:\n        p.starmap(load_convert_labels_back_to_BraTS, zip(nii, [input_folder] * len(nii), [output_folder] * len(nii)))\n\n\nif __name__ == '__main__':\n    brats_data_dir = ...\n\n    task_id = 42\n    task_name = \"BraTS2018\"\n\n    foldername = \"Dataset%03.0d_%s\" % (task_id, task_name)\n\n    # setting up nnU-Net folders\n    out_base = join(nnUNet_raw, foldername)\n    imagestr = join(out_base, \"imagesTr\")\n    labelstr = join(out_base, \"labelsTr\")\n    maybe_mkdir_p(imagestr)\n    maybe_mkdir_p(labelstr)\n\n    case_ids_hgg = subdirs(join(brats_data_dir, \"HGG\"), prefix='Brats', join=False)\n    case_ids_lgg = subdirs(join(brats_data_dir, \"LGG\"), prefix=\"Brats\", join=False)\n\n    print(\"copying hggs\")\n    for c in tqdm(case_ids_hgg):\n        shutil.copy(join(brats_data_dir, \"HGG\", c, c + \"_t1.nii\"), join(imagestr, c + '_0000.nii'))\n        shutil.copy(join(brats_data_dir, \"HGG\", c, c + \"_t1ce.nii\"), join(imagestr, c + '_0001.nii'))\n        shutil.copy(join(brats_data_dir, \"HGG\", c, c + \"_t2.nii\"), join(imagestr, c + '_0002.nii'))\n        shutil.copy(join(brats_data_dir, \"HGG\", c, c + \"_flair.nii\"), join(imagestr, c + '_0003.nii'))\n\n        copy_BraTS_segmentation_and_convert_labels_to_nnUNet(join(brats_data_dir, \"HGG\", c, c + \"_seg.nii\"),\n                                                             join(labelstr, c + '.nii'))\n    print(\"copying lggs\")\n    for c in tqdm(case_ids_lgg):\n        shutil.copy(join(brats_data_dir, \"LGG\", c, c + \"_t1.nii\"), join(imagestr, c + '_0000.nii'))\n        shutil.copy(join(brats_data_dir, \"LGG\", c, c + \"_t1ce.nii\"), join(imagestr, c + '_0001.nii'))\n        shutil.copy(join(brats_data_dir, \"LGG\", c, c + \"_t2.nii\"), join(imagestr, c + '_0002.nii'))\n        shutil.copy(join(brats_data_dir, \"LGG\", c, c + \"_flair.nii\"), join(imagestr, c + '_0003.nii'))\n\n        
copy_BraTS_segmentation_and_convert_labels_to_nnUNet(join(brats_data_dir, \"LGG\", c, c + \"_seg.nii\"),\n                                                             join(labelstr, c + '.nii'))\n\n    generate_dataset_json(out_base,\n                          channel_names={0: 'T1', 1: 'T1ce', 2: 'T2', 3: 'Flair'},\n                          labels={\n                              'background': 0,\n                              'whole tumor': (1, 2, 3),\n                              'tumor core': (2, 3),\n                              'enhancing tumor': (3,)\n                          },\n                          num_training_cases=(len(case_ids_lgg) + len(case_ids_hgg)),\n                          file_ending='.nii',\n                          regions_class_order=(1, 2, 3),\n                          license='see https://www.synapse.org/#!Synapse:syn25829067/wiki/610863',\n                          reference='see https://www.synapse.org/#!Synapse:syn25829067/wiki/610863',\n                          dataset_release='1.0')\n"
  },
  {
    "path": "nnunetv2/dataset_conversion/Dataset043_BraTS19.py",
    "content": "import multiprocessing\nimport shutil\n\nimport SimpleITK as sitk\nimport numpy as np\nfrom tqdm import tqdm\nfrom batchgenerators.utilities.file_and_folder_operations import *\nfrom nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json\nfrom nnunetv2.paths import nnUNet_raw\n\n\ndef copy_BraTS_segmentation_and_convert_labels_to_nnUNet(in_file: str, out_file: str) -> None:\n    # use this for segmentation only!!!\n    # nnUNet wants the labels to be continuous. BraTS is 0, 1, 2, 4 -> we make that into 0, 1, 2, 3\n    img = sitk.ReadImage(in_file)\n    img_npy = sitk.GetArrayFromImage(img)\n\n    uniques = np.unique(img_npy)\n    for u in uniques:\n        if u not in [0, 1, 2, 4]:\n            raise RuntimeError('unexpected label')\n\n    seg_new = np.zeros_like(img_npy)\n    seg_new[img_npy == 4] = 3\n    seg_new[img_npy == 2] = 1\n    seg_new[img_npy == 1] = 2\n    img_corr = sitk.GetImageFromArray(seg_new)\n    img_corr.CopyInformation(img)\n    sitk.WriteImage(img_corr, out_file)\n\n\ndef convert_labels_back_to_BraTS(seg: np.ndarray):\n    new_seg = np.zeros_like(seg)\n    new_seg[seg == 1] = 2\n    new_seg[seg == 3] = 4\n    new_seg[seg == 2] = 1\n    return new_seg\n\n\ndef load_convert_labels_back_to_BraTS(filename, input_folder, output_folder):\n    a = sitk.ReadImage(join(input_folder, filename))\n    b = sitk.GetArrayFromImage(a)\n    c = convert_labels_back_to_BraTS(b)\n    d = sitk.GetImageFromArray(c)\n    d.CopyInformation(a)\n    sitk.WriteImage(d, join(output_folder, filename))\n\n\ndef convert_folder_with_preds_back_to_BraTS_labeling_convention(input_folder: str, output_folder: str,\n                                                                num_processes: int = 12):\n    \"\"\"\n    reads all prediction files (nifti) in the input folder, converts the labels back to BraTS convention and saves the\n    \"\"\"\n    maybe_mkdir_p(output_folder)\n    nii = subfiles(input_folder, suffix='.nii.gz', 
join=False)\n    with multiprocessing.get_context(\"spawn\").Pool(num_processes) as p:\n        p.starmap(load_convert_labels_back_to_BraTS, zip(nii, [input_folder] * len(nii), [output_folder] * len(nii)))\n\n\nif __name__ == '__main__':\n    brats_data_dir = ...\n\n    task_id = 43\n    task_name = \"BraTS2019\"\n\n    foldername = \"Dataset%03.0d_%s\" % (task_id, task_name)\n\n    # setting up nnU-Net folders\n    out_base = join(nnUNet_raw, foldername)\n    imagestr = join(out_base, \"imagesTr\")\n    labelstr = join(out_base, \"labelsTr\")\n    maybe_mkdir_p(imagestr)\n    maybe_mkdir_p(labelstr)\n\n    case_ids_hgg = subdirs(join(brats_data_dir, \"HGG\"), prefix='BraTS', join=False)\n    case_ids_lgg = subdirs(join(brats_data_dir, \"LGG\"), prefix=\"BraTS\", join=False)\n\n    print(\"copying hggs\")\n    for c in tqdm(case_ids_hgg):\n        shutil.copy(join(brats_data_dir, \"HGG\", c, c + \"_t1.nii\"), join(imagestr, c + '_0000.nii'))\n        shutil.copy(join(brats_data_dir, \"HGG\", c, c + \"_t1ce.nii\"), join(imagestr, c + '_0001.nii'))\n        shutil.copy(join(brats_data_dir, \"HGG\", c, c + \"_t2.nii\"), join(imagestr, c + '_0002.nii'))\n        shutil.copy(join(brats_data_dir, \"HGG\", c, c + \"_flair.nii\"), join(imagestr, c + '_0003.nii'))\n\n        copy_BraTS_segmentation_and_convert_labels_to_nnUNet(join(brats_data_dir, \"HGG\", c, c + \"_seg.nii\"),\n                                                             join(labelstr, c + '.nii'))\n    print(\"copying lggs\")\n    for c in tqdm(case_ids_lgg):\n        shutil.copy(join(brats_data_dir, \"LGG\", c, c + \"_t1.nii\"), join(imagestr, c + '_0000.nii'))\n        shutil.copy(join(brats_data_dir, \"LGG\", c, c + \"_t1ce.nii\"), join(imagestr, c + '_0001.nii'))\n        shutil.copy(join(brats_data_dir, \"LGG\", c, c + \"_t2.nii\"), join(imagestr, c + '_0002.nii'))\n        shutil.copy(join(brats_data_dir, \"LGG\", c, c + \"_flair.nii\"), join(imagestr, c + '_0003.nii'))\n\n        
copy_BraTS_segmentation_and_convert_labels_to_nnUNet(join(brats_data_dir, \"LGG\", c, c + \"_seg.nii\"),\n                                                             join(labelstr, c + '.nii'))\n\n    generate_dataset_json(out_base,\n                          channel_names={0: 'T1', 1: 'T1ce', 2: 'T2', 3: 'Flair'},\n                          labels={\n                              'background': 0,\n                              'whole tumor': (1, 2, 3),\n                              'tumor core': (2, 3),\n                              'enhancing tumor': (3,)\n                          },\n                          num_training_cases=(len(case_ids_hgg) + len(case_ids_lgg)),\n                          file_ending='.nii',\n                          regions_class_order=(1, 2, 3),\n                          license='see https://www.synapse.org/#!Synapse:syn25829067/wiki/610863',\n                          reference='see https://www.synapse.org/#!Synapse:syn25829067/wiki/610863',\n                          dataset_release='1.0')\n"
  },
  {
    "path": "nnunetv2/dataset_conversion/Dataset073_Fluo_C3DH_A549_SIM.py",
    "content": "from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json\nfrom nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed\nimport tifffile\nfrom batchgenerators.utilities.file_and_folder_operations import *\nimport shutil\n\n\nif __name__ == '__main__':\n    \"\"\"\n    This is going to be my test dataset for working with tif as input and output images\n    \n    All we do here is copy the files and rename them. Not file conversions take place \n    \"\"\"\n    dataset_name = 'Dataset073_Fluo_C3DH_A549_SIM'\n\n    imagestr = join(nnUNet_raw, dataset_name, 'imagesTr')\n    imagests = join(nnUNet_raw, dataset_name, 'imagesTs')\n    labelstr = join(nnUNet_raw, dataset_name, 'labelsTr')\n    maybe_mkdir_p(imagestr)\n    maybe_mkdir_p(imagests)\n    maybe_mkdir_p(labelstr)\n\n    # we extract the downloaded train and test datasets to two separate folders and name them Fluo-C3DH-A549-SIM_train\n    # and Fluo-C3DH-A549-SIM_test\n    train_source = '/home/fabian/Downloads/Fluo-C3DH-A549-SIM_train'\n    test_source = '/home/fabian/Downloads/Fluo-C3DH-A549-SIM_test'\n\n    # with the old nnU-Net we had to convert all the files to nifti. This is no longer required. We can just copy the\n    # tif files\n\n    # tif is broken when it comes to spacing. No standards. Grr. So when we use tif nnU-Net expects a separate file\n    # that specifies the spacing. This file needs to exist for EVERY training/test case to allow for different spacings\n    # between files. Important! The spacing must align with the axes.\n    # Here when we do print(tifffile.imread('IMAGE').shape) we get (29, 300, 350). The low resolution axis is the first.\n    # The spacing on the website is griven in the wrong axis order. 
Great.\n    spacing = (1, 0.126, 0.126)\n\n    # train set\n    for seq in ['01', '02']:\n        images_dir = join(train_source, seq)\n        seg_dir = join(train_source, seq + '_GT', 'SEG')\n        # if we were to be super clean we would go by IDs but here we just trust the files are sorted the correct way.\n        # Simpler filenames in the cell tracking challenge would be soooo nice.\n        images = subfiles(images_dir, suffix='.tif', sort=True, join=False)\n        segs = subfiles(seg_dir, suffix='.tif', sort=True, join=False)\n        for i, (im, se) in enumerate(zip(images, segs)):\n            target_name = f'{seq}_image_{i:03d}'\n            # we still need the '_0000' suffix for images! Otherwise we would not be able to support multiple input\n            # channels distributed over separate files\n            shutil.copy(join(images_dir, im), join(imagestr, target_name + '_0000.tif'))\n            # spacing file!\n            save_json({'spacing': spacing}, join(imagestr, target_name + '.json'))\n            shutil.copy(join(seg_dir, se), join(labelstr, target_name + '.tif'))\n            # spacing file!\n            save_json({'spacing': spacing}, join(labelstr, target_name + '.json'))\n\n    # test set, same as train just without the segmentations\n    for seq in ['01', '02']:\n        images_dir = join(test_source, seq)\n        images = subfiles(images_dir, suffix='.tif', sort=True, join=False)\n        for i, im in enumerate(images):\n            target_name = f'{seq}_image_{i:03d}'\n            shutil.copy(join(images_dir, im), join(imagests, target_name + '_0000.tif'))\n            # spacing file!\n            save_json({'spacing': spacing}, join(imagests, target_name + '.json'))\n\n    # now we generate the dataset json\n    generate_dataset_json(\n        join(nnUNet_raw, dataset_name),\n        {0: 'fluorescence_microscopy'},\n        {'background': 0, 'cell': 1},\n        60,\n        '.tif'\n    )\n\n    # custom split to ensure we are 
stratifying properly. This dataset only has 2 folds\n    caseids = [i[:-4] for i in subfiles(labelstr, suffix='.tif', join=False)]\n    splits = []\n    splits.append(\n        {'train': [i for i in caseids if i.startswith('01_')], 'val': [i for i in caseids if i.startswith('02_')]}\n    )\n    splits.append(\n        {'train': [i for i in caseids if i.startswith('02_')], 'val': [i for i in caseids if i.startswith('01_')]}\n    )\n    # the preprocessed folder may not exist yet if preprocessing has not been run\n    maybe_mkdir_p(join(nnUNet_preprocessed, dataset_name))\n    save_json(splits, join(nnUNet_preprocessed, dataset_name, 'splits_final.json'))"
  },
  {
    "path": "nnunetv2/dataset_conversion/Dataset114_MNMs.py",
    "content": "import csv\nimport os\nimport random\nfrom pathlib import Path\n\nimport nibabel as nib\nfrom batchgenerators.utilities.file_and_folder_operations import load_json, save_json\n\nfrom nnunetv2.dataset_conversion.Dataset027_ACDC import make_out_dirs\nfrom nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json\nfrom nnunetv2.paths import nnUNet_preprocessed\n\n\ndef read_csv(csv_file: str):\n    patient_info = {}\n\n    with open(csv_file) as csvfile:\n        reader = csv.reader(csvfile)\n        headers = next(reader)\n        patient_index = headers.index(\"External code\")\n        ed_index = headers.index(\"ED\")\n        es_index = headers.index(\"ES\")\n        vendor_index = headers.index(\"Vendor\")\n\n        for row in reader:\n            patient_info[row[patient_index]] = {\n                \"ed\": int(row[ed_index]),\n                \"es\": int(row[es_index]),\n                \"vendor\": row[vendor_index],\n            }\n\n    return patient_info\n\n\n# ------------------------------------------------------------------------------\n# Conversion to nnUNet format\n# ------------------------------------------------------------------------------\ndef convert_mnms(src_data_folder: Path, csv_file_name: str, dataset_id: int):\n    out_dir, out_train_dir, out_labels_dir, out_test_dir = make_out_dirs(dataset_id, task_name=\"MNMs\")\n    patients_train = [f for f in (src_data_folder / \"Training\" / \"Labeled\").iterdir() if f.is_dir()]\n    patients_test = [f for f in (src_data_folder / \"Testing\").iterdir() if f.is_dir()]\n\n    patient_info = read_csv(str(src_data_folder / csv_file_name))\n\n    save_cardiac_phases(patients_train, patient_info, out_train_dir, out_labels_dir)\n    save_cardiac_phases(patients_test, patient_info, out_test_dir)\n\n    # There are non-orthonormal direction cosines in the test and validation data.\n    # Not sure if the data should be fixed, or we should skip the problematic data.\n    # 
patients_val = [f for f in (src_data_folder / \"Validation\").iterdir() if f.is_dir()]\n    # save_cardiac_phases(patients_val, patient_info, out_train_dir, out_labels_dir)\n\n    generate_dataset_json(\n        str(out_dir),\n        channel_names={\n            0: \"cineMRI\",\n        },\n        labels={\"background\": 0, \"LVBP\": 1, \"LVM\": 2, \"RV\": 3},\n        file_ending=\".nii.gz\",\n        num_training_cases=len(patients_train) * 2,  # 2 since we have ED and ES for each patient\n    )\n\n\ndef save_cardiac_phases(\n    patients: list[Path], patient_info: dict[str, dict[str, int]], out_dir: Path, labels_dir: Path = None\n):\n    for patient in patients:\n        print(f\"Processing patient: {patient.name}\")\n\n        image = nib.load(patient / f\"{patient.name}_sa.nii.gz\")\n        ed_frame = patient_info[patient.name][\"ed\"]\n        es_frame = patient_info[patient.name][\"es\"]\n\n        save_extracted_nifti_slice(image, ed_frame=ed_frame, es_frame=es_frame, out_dir=out_dir, patient=patient)\n\n        if labels_dir:\n            label = nib.load(patient / f\"{patient.name}_sa_gt.nii.gz\")\n            save_extracted_nifti_slice(label, ed_frame=ed_frame, es_frame=es_frame, out_dir=labels_dir, patient=patient)\n\n\ndef save_extracted_nifti_slice(image, ed_frame: int, es_frame: int, out_dir: Path, patient: Path):\n    # Save only extracted diastole and systole slices from the 4D H x W x D x time volume.\n    image_ed = nib.Nifti1Image(image.dataobj[..., ed_frame], image.affine)\n    image_es = nib.Nifti1Image(image.dataobj[..., es_frame], image.affine)\n\n    # Labels do not have modality identifiers. 
Labels always end with 'gt'.\n    suffix = \".nii.gz\" if image.get_filename().endswith(\"_gt.nii.gz\") else \"_0000.nii.gz\"\n\n    nib.save(image_ed, str(out_dir / f\"{patient.name}_frame{ed_frame:02d}{suffix}\"))\n    nib.save(image_es, str(out_dir / f\"{patient.name}_frame{es_frame:02d}{suffix}\"))\n\n\n# ------------------------------------------------------------------------------\n# Create custom splits\n# ------------------------------------------------------------------------------\ndef create_custom_splits(src_data_folder: Path, csv_file: str, dataset_id: int, num_val_patients: int = 25):\n    existing_splits = os.path.join(nnUNet_preprocessed, f\"Dataset{dataset_id}_MNMs\", \"splits_final.json\")\n    splits = load_json(existing_splits)\n\n    patients_train = [f.name for f in (src_data_folder / \"Training\" / \"Labeled\").iterdir() if f.is_dir()]\n    # Filter out any patients not in the training set\n    patient_info = {\n        patient: data\n        for patient, data in read_csv(str(src_data_folder / csv_file)).items()\n        if patient in patients_train\n    }\n\n    # Get train and validation patients for both vendors\n    patients_a = [patient for patient, patient_data in patient_info.items() if patient_data[\"vendor\"] == \"A\"]\n    patients_b = [patient for patient, patient_data in patient_info.items() if patient_data[\"vendor\"] == \"B\"]\n    train_a, val_a = get_vendor_split(patients_a, num_val_patients)\n    train_b, val_b = get_vendor_split(patients_b, num_val_patients)\n\n    # Build filenames from corresponding patient frames\n    train_a = [f\"{patient}_frame{patient_info[patient][frame]:02d}\" for patient in train_a for frame in [\"es\", \"ed\"]]\n    train_b = [f\"{patient}_frame{patient_info[patient][frame]:02d}\" for patient in train_b for frame in [\"es\", \"ed\"]]\n    train_a_mix_1, train_a_mix_2 = train_a[: len(train_a) // 2], train_a[len(train_a) // 2 :]\n    train_b_mix_1, train_b_mix_2 = train_b[: len(train_b) // 2], 
train_b[len(train_b) // 2 :]\n    val_a = [f\"{patient}_frame{patient_info[patient][frame]:02d}\" for patient in val_a for frame in [\"es\", \"ed\"]]\n    val_b = [f\"{patient}_frame{patient_info[patient][frame]:02d}\" for patient in val_b for frame in [\"es\", \"ed\"]]\n\n    for train_set in [train_a, train_b, train_a_mix_1 + train_b_mix_1, train_a_mix_2 + train_b_mix_2]:\n        # For each train set, we evaluate on A, B and (A + B) respectively\n        # See table 3 from the original paper for more details.\n        splits.append({\"train\": train_set, \"val\": val_a})\n        splits.append({\"train\": train_set, \"val\": val_b})\n        splits.append({\"train\": train_set, \"val\": val_a + val_b})\n\n    save_json(splits, existing_splits)\n\n\ndef get_vendor_split(patients: list[str], num_val_patients: int):\n    random.shuffle(patients)\n    total_patients = len(patients)\n    num_training_patients = total_patients - num_val_patients\n    return patients[:num_training_patients], patients[num_training_patients:]\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    class RawTextArgumentDefaultsHelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelpFormatter):\n        pass\n\n    parser = argparse.ArgumentParser(add_help=False, formatter_class=RawTextArgumentDefaultsHelpFormatter)\n    parser.add_argument(\n        \"-h\",\n        \"--help\",\n        action=\"help\",\n        default=argparse.SUPPRESS,\n        help=\"MNMs conversion utility helper. This script can be used to convert MNMs data into the expected nnUNet \"\n        \"format. 
It can also be used to create additional custom splits, for explicitly training on combinations \"\n        \"of vendors A and B (see `--custom_splits`).\\n\"\n        \"If you wish to generate the custom splits, run the following pipeline:\\n\\n\"\n        \"(1) Run `Dataset114_MNMs -i <raw_data_dir>`\\n\"\n        \"(2) Run `nnUNetv2_plan_and_preprocess -d 114 --verify_dataset_integrity`\\n\"\n        \"(3) Start training, but stop after initial splits are created: `nnUNetv2_train 114 2d 0`\\n\"\n        \"(4) Re-run `Dataset114_MNMs`, with `-s True`.\\n\"\n        \"(5) Re-run training.\\n\",\n    )\n    parser.add_argument(\n        \"-i\",\n        \"--input_folder\",\n        type=str,\n        default=\"./data/M&Ms/OpenDataset/\",\n        help=\"The downloaded MNMs dataset dir. Should contain a csv file, as well as Training, Validation and Testing \"\n        \"folders.\",\n    )\n    parser.add_argument(\n        \"-c\",\n        \"--csv_file_name\",\n        type=str,\n        default=\"211230_M&Ms_Dataset_information_diagnosis_opendataset.csv\",\n        help=\"The csv file containing the dataset information.\",\n    )\n    parser.add_argument(\"-d\", \"--dataset_id\", type=int, default=114, help=\"nnUNet Dataset ID.\")\n    parser.add_argument(\n        \"-s\",\n        \"--custom_splits\",\n        # type=bool would treat any non-empty string (including \"False\") as True, so parse the text explicitly\n        type=lambda s: str(s).lower() in (\"true\", \"1\", \"yes\"),\n        default=False,\n        help=\"Whether to append custom splits for training and testing on different vendors. If True, will create \"\n        \"splits for training on patients from vendors A, B or a mix of A and B. Splits are tested on hold-out \"\n        \"validation sets of patients from A, B or A and B combined. 
See section 2.4 and table 3 from \"\n        \"https://arxiv.org/abs/2011.07592 for more info.\",\n    )\n\n    args = parser.parse_args()\n    args.input_folder = Path(args.input_folder)\n\n    if args.custom_splits:\n        print(\"Appending custom splits...\")\n        create_custom_splits(args.input_folder, args.csv_file_name, args.dataset_id)\n    else:\n        print(\"Converting...\")\n        convert_mnms(args.input_folder, args.csv_file_name, args.dataset_id)\n\n    print(\"Done!\")\n"
  },
  {
    "path": "nnunetv2/dataset_conversion/Dataset115_EMIDEC.py",
    "content": "import shutil\nfrom pathlib import Path\n\nfrom nnunetv2.dataset_conversion.Dataset027_ACDC import make_out_dirs\nfrom nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json\n\n\ndef copy_files(src_data_dir: Path, src_test_dir: Path, train_dir: Path, labels_dir: Path, test_dir: Path):\n    \"\"\"Copy files from the EMIDEC dataset to the nnUNet dataset folder. Returns the number of training cases.\"\"\"\n    patients_train = sorted([f for f in src_data_dir.iterdir() if f.is_dir()])\n    patients_test = sorted([f for f in src_test_dir.iterdir() if f.is_dir()])\n\n    # Copy training files and corresponding labels.\n    for patient in patients_train:\n        train_file = patient / \"Images\" / f\"{patient.name}.nii.gz\"\n        label_file = patient / \"Contours\" / f\"{patient.name}.nii.gz\"\n        shutil.copy(train_file, train_dir / f\"{train_file.stem.split('.')[0]}_0000.nii.gz\")\n        shutil.copy(label_file, labels_dir)\n\n    # Copy test files.\n    for patient in patients_test:\n        test_file = patient / \"Images\" / f\"{patient.name}.nii.gz\"\n        shutil.copy(test_file, test_dir / f\"{test_file.stem.split('.')[0]}_0000.nii.gz\")\n\n    return len(patients_train)\n\n\ndef convert_emidec(src_data_dir: str, src_test_dir: str, dataset_id=27):\n    out_dir, train_dir, labels_dir, test_dir = make_out_dirs(dataset_id=dataset_id, task_name=\"EMIDEC\")\n    num_training_cases = copy_files(Path(src_data_dir), Path(src_test_dir), train_dir, labels_dir, test_dir)\n\n    generate_dataset_json(\n        str(out_dir),\n        channel_names={\n            0: \"cineMRI\",\n        },\n        labels={\n            \"background\": 0,\n            \"cavity\": 1,\n            \"normal_myocardium\": 2,\n            \"myocardial_infarction\": 3,\n            \"no_reflow\": 4,\n        },\n        file_ending=\".nii.gz\",\n        num_training_cases=num_training_cases,\n    )\n\n\nif __name__ == \"__main__\":\n    import 
argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"-i\", \"--input_dir\", type=str, help=\"The EMIDEC dataset directory.\")\n    parser.add_argument(\"-t\", \"--test_dir\", type=str, help=\"The EMIDEC test set directory.\")\n    parser.add_argument(\n        \"-d\", \"--dataset_id\", required=False, type=int, default=115, help=\"nnU-Net Dataset ID, default: 115\"\n    )\n    args = parser.parse_args()\n    print(\"Converting...\")\n    convert_emidec(args.input_dir, args.test_dir, args.dataset_id)\n    print(\"Done!\")\n"
  },
  {
    "path": "nnunetv2/dataset_conversion/Dataset119_ToothFairy2_All.py",
    "content": "from typing import Dict, Any\nimport os\nfrom os.path import join\nimport json\nimport random\nimport multiprocessing\n\nimport SimpleITK as sitk\nimport numpy as np\nfrom tqdm import tqdm\n\n\ndef mapping_DS119() -> Dict[int, int]:\n    \"\"\"Remove all NA Classes and make Class IDs continuous\"\"\"\n    mapping = {}\n    mapping.update({i: i for i in range(1, 19)})  # [1-10]->[1-10] | [11-18]->[11-18]\n    mapping.update({i: i - 2 for i in range(21, 29)})  # [21-28]->[19-26]\n    mapping.update({i: i - 4 for i in range(31, 39)})  # [31-38]->[27-34]\n    mapping.update({i: i - 6 for i in range(41, 49)})  # [41-48]->[35-42]\n    return mapping\n\n\ndef mapping_DS120() -> Dict[int, int]:\n    \"\"\"Remove Only Keep Teeth and Jaw Classes\"\"\"\n    mapping = {}\n    mapping.update({i: i for i in range(1, 3)})  # [0-2] -> [0-2]\n    mapping.update({i: i - 8 for i in range(11, 19)})  # [11-18]->[3-10]\n    mapping.update({i: i - 10 for i in range(21, 29)})  # [21-28]->[11-18]\n    mapping.update({i: i - 12 for i in range(31, 39)})  # [31-38]->[19-26]\n    mapping.update({i: i - 14 for i in range(41, 49)})  # [41-48]->[27-34]\n    return mapping\n\n\ndef mapping_DS121() -> Dict[int, int]:\n    \"\"\"Remove Only Keep Teeth and Jaw Classes\"\"\"\n    mapping = {}\n    mapping.update({i: i - 10 for i in range(11, 19)})  # [11-18]->[3-8]\n    mapping.update({i: i - 12 for i in range(21, 29)})  # [21-28]->[11-16]\n    mapping.update({i: i - 14 for i in range(31, 39)})  # [31-38]->[19-24]\n    mapping.update({i: i - 16 for i in range(41, 49)})  # [41-48]->[27-32]\n    return mapping\n\n\ndef load_json(json_file: str) -> Any:\n    with open(json_file, \"r\") as f:\n        data = json.load(f)\n    return data\n\n\ndef write_json(json_file: str, data: Any, indent: int = 4) -> None:\n    with open(json_file, \"w\") as f:\n        json.dump(data, f, indent=indent)\n\n\ndef image_to_nifi(input_path: str, output_path: str) -> None:\n    image_sitk = 
sitk.ReadImage(input_path)\n    sitk.WriteImage(image_sitk, output_path)\n\n\ndef label_mapping(input_path: str, output_path: str, mapping: Dict[int, int] = None) -> None:\n\n    label_sitk = sitk.ReadImage(input_path)\n    if mapping is not None:\n        label_np = sitk.GetArrayFromImage(label_sitk)\n\n        label_np_new = np.zeros_like(label_np, dtype=np.uint8)\n        for org_id, new_id in mapping.items():\n            label_np_new[label_np == org_id] = new_id\n\n        label_sitk_new = sitk.GetImageFromArray(label_np_new)\n        label_sitk_new.CopyInformation(label_sitk)\n        sitk.WriteImage(label_sitk_new, output_path)\n    else:\n        sitk.WriteImage(label_sitk, output_path)\n\n\ndef process_images(files: list, img_dir_in: str, img_dir_out: str, n_processes: int = 12) -> None:\n\n    os.makedirs(img_dir_out, exist_ok=True)\n\n    iterable = [\n        {\n            \"input_path\": join(img_dir_in, file),\n            \"output_path\": join(img_dir_out, file.replace(\".mha\", \".nii.gz\")),\n        }\n        for file in files\n    ]\n    with multiprocessing.Pool(processes=n_processes) as pool:\n        jobs = [pool.apply_async(image_to_nifi, kwds={**args}) for args in iterable]\n        _ = [job.get() for job in tqdm(jobs, desc=\"Process Images\")]\n\n\ndef process_labels(\n    files: list, lbl_dir_in: str, lbl_dir_out: str, mapping: Dict[int, int], n_processes: int = 12\n) -> None:\n\n    os.makedirs(lbl_dir_out, exist_ok=True)\n\n    iterable = [\n        {\n            \"input_path\": join(lbl_dir_in, file),\n            \"output_path\": join(lbl_dir_out, file.replace(\".mha\", \".nii.gz\")),\n            \"mapping\": mapping,\n        }\n        for file in files\n    ]\n    with multiprocessing.Pool(processes=n_processes) as pool:\n        jobs = [pool.apply_async(label_mapping, kwds={**args}) for args in iterable]\n        _ = [job.get() for job in tqdm(jobs, desc=\"Process Labels...\")]\n\n\ndef process_ds(\n    root: str, input_ds: str, 
output_ds: str, mapping: dict, image_link: str = None\n) -> None:\n    os.makedirs(join(root, output_ds), exist_ok=True)\n    os.makedirs(join(root, output_ds, \"labelsTr\"), exist_ok=True)\n    # --- Handle Labels --- #\n    lbl_files = os.listdir(join(root, input_ds, \"labelsTr\"))\n    lbl_dir_in = join(root, input_ds, \"labelsTr\")\n    lbl_dir_out = join(root, output_ds, \"labelsTr\")\n\n    process_labels(lbl_files, lbl_dir_in, lbl_dir_out, mapping, n_processes=12)\n\n    # --- Handle Images --- #\n    img_files = os.listdir(join(root, input_ds, \"imagesTr\"))\n    dataset = {}\n    if image_link is None:\n        img_dir_in = join(root, input_ds, \"imagesTr\")\n        img_dir_out = join(root, output_ds, \"imagesTr\")\n\n        process_images(img_files, img_dir_in, img_dir_out, n_processes=12)\n    else:\n        base_name = [file.replace(\"_0000.mha\", \"\") for file in img_files]\n        for name in base_name:\n            dataset[name] = {\n                \"images\": [join(\"..\", image_link, \"imagesTr\", name + \"_0000.nii.gz\")],\n                \"label\": join(\"labelsTr\", name + \".nii.gz\"),\n            }\n\n    # --- Generate dataset.json --- #\n    dataset_json = load_json(join(root, input_ds, \"dataset.json\"))\n    dataset_json[\"file_ending\"] = \".nii.gz\"\n    dataset_json[\"name\"] = output_ds\n    dataset_json[\"numTraining\"] = len(lbl_files)\n    if dataset != {}:\n        dataset_json[\"dataset\"] = dataset\n\n    label_dict = dataset_json[\"labels\"]\n    label_dict_new = {\"background\": 0}\n    for k, v in label_dict.items():\n        if v in mapping.keys():\n            label_dict_new[k] = mapping[v]\n    dataset_json[\"labels\"] = label_dict_new\n    write_json(join(root, output_ds, \"dataset.json\"), dataset_json)\n\n    # --- Generate splits_final.json --- #\n    img_names = [file.replace(\"_0000.mha\", \"\") for file in img_files]\n\n    random_seed = 42\n    random.seed(random_seed)\n    random.shuffle(img_names)\n\n    
split_index = int(len(img_names) * 0.7)  # 70:30 split\n    train_files = img_names[:split_index]\n    val_files = img_names[split_index:]\n    train_files.sort()\n    val_files.sort()\n\n    split = [{\"train\": train_files, \"val\": val_files}]\n    write_json(join(root, output_ds, \"splits_final.json\"), split)\n\n\nif __name__ == \"__main__\":\n    # Different nnUNet Datasets\n    # Dataset 112: Raw\n    # Dataset 119: Replace NaN classes\n    # Dataset 120: Only Teeth + Jaw Classes\n    # Dataset 121: Only Teeth Classes\n\n    root = \"/media/l727r/data/Teeth_Data/ToothFairy2_Dataset\"\n\n    process_ds(root, \"Dataset112_ToothFairy2\", \"Dataset119_ToothFairy2_All\", mapping_DS119(), None)\n    # process_ds(\n    #     root,\n    #     \"Dataset112_ToothFairy2\",\n    #     \"Dataset120_ToothFairy2_JawTeeth\",\n    #     mapping_DS120(),\n    #     \"Dataset119_ToothFairy2_All\",\n    # )\n    # process_ds(\n    #     root,\n    #     \"Dataset112_ToothFairy2\",\n    #     \"Dataset121_ToothFairy2_Teeth\",\n    #     mapping_DS121(),\n    #     \"Dataset119_ToothFairy2_All\",\n    # )\n"
  },
  {
    "path": "nnunetv2/dataset_conversion/Dataset120_RoadSegmentation.py",
    "content": "import multiprocessing\nimport shutil\nfrom multiprocessing import Pool\n\nfrom batchgenerators.utilities.file_and_folder_operations import *\n\nfrom nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json\nfrom nnunetv2.paths import nnUNet_raw\nfrom skimage import io\nfrom acvl_utils.morphology.morphology_helper import generic_filter_components\nfrom scipy.ndimage import binary_fill_holes\n\n\ndef load_and_convert_case(input_image: str, input_seg: str, output_image: str, output_seg: str,\n                          min_component_size: int = 50):\n    seg = io.imread(input_seg)\n    seg[seg == 255] = 1\n    image = io.imread(input_image)\n    image = image.sum(2)\n    mask = image == (3 * 255)\n    # the dataset has large white areas in which road segmentations can exist but no image information is available.\n    # Remove the road label in these areas\n    mask = generic_filter_components(mask, filter_fn=lambda ids, sizes: [i for j, i in enumerate(ids) if\n                                                                         sizes[j] > min_component_size])\n    mask = binary_fill_holes(mask)\n    seg[mask] = 0\n    io.imsave(output_seg, seg, check_contrast=False)\n    shutil.copy(input_image, output_image)\n\n\nif __name__ == \"__main__\":\n    # extracted archive from https://www.kaggle.com/datasets/insaff/massachusetts-roads-dataset?resource=download\n    source = '/media/fabian/data/raw_datasets/Massachussetts_road_seg/road_segmentation_ideal'\n\n    dataset_name = 'Dataset120_RoadSegmentation'\n\n    imagestr = join(nnUNet_raw, dataset_name, 'imagesTr')\n    imagests = join(nnUNet_raw, dataset_name, 'imagesTs')\n    labelstr = join(nnUNet_raw, dataset_name, 'labelsTr')\n    labelsts = join(nnUNet_raw, dataset_name, 'labelsTs')\n    maybe_mkdir_p(imagestr)\n    maybe_mkdir_p(imagests)\n    maybe_mkdir_p(labelstr)\n    maybe_mkdir_p(labelsts)\n\n    train_source = join(source, 'training')\n    test_source = join(source, 
'testing')\n\n    with multiprocessing.get_context(\"spawn\").Pool(8) as p:\n\n        # not all training images have a segmentation\n        valid_ids = subfiles(join(train_source, 'output'), join=False, suffix='png')\n        num_train = len(valid_ids)\n        r = []\n        for v in valid_ids:\n            r.append(\n                p.starmap_async(\n                    load_and_convert_case,\n                    ((\n                         join(train_source, 'input', v),\n                         join(train_source, 'output', v),\n                         join(imagestr, v[:-4] + '_0000.png'),\n                         join(labelstr, v),\n                         50\n                     ),)\n                )\n            )\n\n        # test set\n        valid_ids = subfiles(join(test_source, 'output'), join=False, suffix='png')\n        for v in valid_ids:\n            r.append(\n                p.starmap_async(\n                    load_and_convert_case,\n                    ((\n                         join(test_source, 'input', v),\n                         join(test_source, 'output', v),\n                         join(imagests, v[:-4] + '_0000.png'),\n                         join(labelsts, v),\n                         50\n                     ),)\n                )\n            )\n        _ = [i.get() for i in r]\n\n    generate_dataset_json(join(nnUNet_raw, dataset_name), {0: 'R', 1: 'G', 2: 'B'}, {'background': 0, 'road': 1},\n                          num_train, '.png', dataset_name=dataset_name)\n"
  },
  {
    "path": "nnunetv2/dataset_conversion/Dataset137_BraTS21.py",
    "content": "import multiprocessing\nimport shutil\n\nimport SimpleITK as sitk\nimport numpy as np\nfrom batchgenerators.utilities.file_and_folder_operations import *\nfrom nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json\nfrom nnunetv2.paths import nnUNet_raw\n\n\ndef copy_BraTS_segmentation_and_convert_labels_to_nnUNet(in_file: str, out_file: str) -> None:\n    # use this for segmentation only!!!\n    # nnUNet wants the labels to be continuous. BraTS is 0, 1, 2, 4 -> we make that into 0, 1, 2, 3\n    img = sitk.ReadImage(in_file)\n    img_npy = sitk.GetArrayFromImage(img)\n\n    uniques = np.unique(img_npy)\n    for u in uniques:\n        if u not in [0, 1, 2, 4]:\n            raise RuntimeError('unexpected label')\n\n    seg_new = np.zeros_like(img_npy)\n    seg_new[img_npy == 4] = 3\n    seg_new[img_npy == 2] = 1\n    seg_new[img_npy == 1] = 2\n    img_corr = sitk.GetImageFromArray(seg_new)\n    img_corr.CopyInformation(img)\n    sitk.WriteImage(img_corr, out_file)\n\n\ndef convert_labels_back_to_BraTS(seg: np.ndarray):\n    new_seg = np.zeros_like(seg)\n    new_seg[seg == 1] = 2\n    new_seg[seg == 3] = 4\n    new_seg[seg == 2] = 1\n    return new_seg\n\n\ndef load_convert_labels_back_to_BraTS(filename, input_folder, output_folder):\n    a = sitk.ReadImage(join(input_folder, filename))\n    b = sitk.GetArrayFromImage(a)\n    c = convert_labels_back_to_BraTS(b)\n    d = sitk.GetImageFromArray(c)\n    d.CopyInformation(a)\n    sitk.WriteImage(d, join(output_folder, filename))\n\n\ndef convert_folder_with_preds_back_to_BraTS_labeling_convention(input_folder: str, output_folder: str, num_processes: int = 12):\n    \"\"\"\n    reads all prediction files (nifti) in the input folder, converts the labels back to BraTS convention and saves the\n    \"\"\"\n    maybe_mkdir_p(output_folder)\n    nii = subfiles(input_folder, suffix='.nii.gz', join=False)\n    with multiprocessing.get_context(\"spawn\").Pool(num_processes) as p:\n        
p.starmap(load_convert_labels_back_to_BraTS, zip(nii, [input_folder] * len(nii), [output_folder] * len(nii)))\n\n\nif __name__ == '__main__':\n    brats_data_dir = '/home/isensee/drives/E132-Rohdaten/BraTS_2021/training'\n\n    task_id = 137\n    task_name = \"BraTS2021\"\n\n    foldername = \"Dataset%03.0d_%s\" % (task_id, task_name)\n\n    # setting up nnU-Net folders\n    out_base = join(nnUNet_raw, foldername)\n    imagestr = join(out_base, \"imagesTr\")\n    labelstr = join(out_base, \"labelsTr\")\n    maybe_mkdir_p(imagestr)\n    maybe_mkdir_p(labelstr)\n\n    case_ids = subdirs(brats_data_dir, prefix='BraTS', join=False)\n\n    for c in case_ids:\n        shutil.copy(join(brats_data_dir, c, c + \"_t1.nii.gz\"), join(imagestr, c + '_0000.nii.gz'))\n        shutil.copy(join(brats_data_dir, c, c + \"_t1ce.nii.gz\"), join(imagestr, c + '_0001.nii.gz'))\n        shutil.copy(join(brats_data_dir, c, c + \"_t2.nii.gz\"), join(imagestr, c + '_0002.nii.gz'))\n        shutil.copy(join(brats_data_dir, c, c + \"_flair.nii.gz\"), join(imagestr, c + '_0003.nii.gz'))\n\n        copy_BraTS_segmentation_and_convert_labels_to_nnUNet(join(brats_data_dir, c, c + \"_seg.nii.gz\"),\n                                                             join(labelstr, c + '.nii.gz'))\n\n    generate_dataset_json(out_base,\n                          channel_names={0: 'T1', 1: 'T1ce', 2: 'T2', 3: 'Flair'},\n                          labels={\n                              'background': 0,\n                              'whole tumor': (1, 2, 3),\n                              'tumor core': (2, 3),\n                              'enhancing tumor': (3, )\n                          },\n                          num_training_cases=len(case_ids),\n                          file_ending='.nii.gz',\n                          regions_class_order=(1, 2, 3),\n                          license='see https://www.synapse.org/#!Synapse:syn25829067/wiki/610863',\n                          reference='see 
https://www.synapse.org/#!Synapse:syn25829067/wiki/610863',\n                          dataset_release='1.0')\n"
  },
  {
    "path": "nnunetv2/dataset_conversion/Dataset218_Amos2022_task1.py",
    "content": "from batchgenerators.utilities.file_and_folder_operations import *\nimport shutil\nfrom nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json\nfrom nnunetv2.paths import nnUNet_raw\n\n\ndef convert_amos_task1(amos_base_dir: str, nnunet_dataset_id: int = 218):\n    \"\"\"\n    AMOS doesn't say anything about how the validation set is supposed to be used. So we just incorporate that into\n    the train set. Having a 5-fold cross-validation is superior to a single train:val split\n    \"\"\"\n    task_name = \"AMOS2022_postChallenge_task1\"\n\n    foldername = \"Dataset%03.0d_%s\" % (nnunet_dataset_id, task_name)\n\n    # setting up nnU-Net folders\n    out_base = join(nnUNet_raw, foldername)\n    imagestr = join(out_base, \"imagesTr\")\n    imagests = join(out_base, \"imagesTs\")\n    labelstr = join(out_base, \"labelsTr\")\n    maybe_mkdir_p(imagestr)\n    maybe_mkdir_p(imagests)\n    maybe_mkdir_p(labelstr)\n\n    dataset_json_source = load_json(join(amos_base_dir, 'dataset.json'))\n\n    training_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['training']]\n    tr_ctr = 0\n    for tr in training_identifiers:\n        if int(tr.split(\"_\")[-1]) <= 410: # these are the CT images\n            tr_ctr += 1\n            shutil.copy(join(amos_base_dir, 'imagesTr', tr + '.nii.gz'), join(imagestr, f'{tr}_0000.nii.gz'))\n            shutil.copy(join(amos_base_dir, 'labelsTr', tr + '.nii.gz'), join(labelstr, f'{tr}.nii.gz'))\n\n    test_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['test']]\n    for ts in test_identifiers:\n        if int(ts.split(\"_\")[-1]) <= 500: # these are the CT images\n            shutil.copy(join(amos_base_dir, 'imagesTs', ts + '.nii.gz'), join(imagests, f'{ts}_0000.nii.gz'))\n\n    val_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['validation']]\n    for vl in val_identifiers:\n        if int(vl.split(\"_\")[-1]) <= 409: # 
these are the CT images\n            tr_ctr += 1\n            shutil.copy(join(amos_base_dir, 'imagesVa', vl + '.nii.gz'), join(imagestr, f'{vl}_0000.nii.gz'))\n            shutil.copy(join(amos_base_dir, 'labelsVa', vl + '.nii.gz'), join(labelstr, f'{vl}.nii.gz'))\n\n    generate_dataset_json(out_base, {0: \"CT\"}, labels={v: int(k) for k,v in dataset_json_source['labels'].items()},\n                          num_training_cases=tr_ctr, file_ending='.nii.gz',\n                          dataset_name=task_name, reference='https://amos22.grand-challenge.org/',\n                          release='https://zenodo.org/record/7262581',\n                          overwrite_image_reader_writer='NibabelIOWithReorient',\n                          description=\"This is the dataset as released AFTER the challenge event. It has the \"\n                                      \"validation set gt in it! We just use the validation images as additional \"\n                                      \"training cases because AMOS doesn't specify how they should be used. nnU-Net's\"\n                                      \" 5-fold CV is better than some random train:val split.\")\n\n\nif __name__ == '__main__':\n    import argparse\n    parser = argparse.ArgumentParser()\n    parser.add_argument('input_folder', type=str,\n                        help=\"The downloaded and extracted AMOS2022 (https://amos22.grand-challenge.org/) data. \"\n                             \"Use this link: https://zenodo.org/record/7262581. \"\n                             \"You need to specify the folder with the imagesTr, imagesVa, labelsTr etc subfolders here!\")\n    parser.add_argument('-d', required=False, type=int, default=218, help='nnU-Net Dataset ID, default: 218')\n    args = parser.parse_args()\n    amos_base = args.input_folder\n    convert_amos_task1(amos_base, args.d)\n\n\n"
  },
  {
    "path": "nnunetv2/dataset_conversion/Dataset219_Amos2022_task2.py",
    "content": "from batchgenerators.utilities.file_and_folder_operations import *\nimport shutil\nfrom nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json\nfrom nnunetv2.paths import nnUNet_raw\n\n\ndef convert_amos_task2(amos_base_dir: str, nnunet_dataset_id: int = 219):\n    \"\"\"\n    AMOS doesn't say anything about how the validation set is supposed to be used. So we just incorporate that into\n    the train set. Having a 5-fold cross-validation is superior to a single train:val split\n    \"\"\"\n    task_name = \"AMOS2022_postChallenge_task2\"\n\n    foldername = \"Dataset%03.0d_%s\" % (nnunet_dataset_id, task_name)\n\n    # setting up nnU-Net folders\n    out_base = join(nnUNet_raw, foldername)\n    imagestr = join(out_base, \"imagesTr\")\n    imagests = join(out_base, \"imagesTs\")\n    labelstr = join(out_base, \"labelsTr\")\n    maybe_mkdir_p(imagestr)\n    maybe_mkdir_p(imagests)\n    maybe_mkdir_p(labelstr)\n\n    dataset_json_source = load_json(join(amos_base_dir, 'dataset.json'))\n\n    training_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['training']]\n    for tr in training_identifiers:\n        shutil.copy(join(amos_base_dir, 'imagesTr', tr + '.nii.gz'), join(imagestr, f'{tr}_0000.nii.gz'))\n        shutil.copy(join(amos_base_dir, 'labelsTr', tr + '.nii.gz'), join(labelstr, f'{tr}.nii.gz'))\n\n    test_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['test']]\n    for ts in test_identifiers:\n        shutil.copy(join(amos_base_dir, 'imagesTs', ts + '.nii.gz'), join(imagests, f'{ts}_0000.nii.gz'))\n\n    val_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['validation']]\n    for vl in val_identifiers:\n        shutil.copy(join(amos_base_dir, 'imagesVa', vl + '.nii.gz'), join(imagestr, f'{vl}_0000.nii.gz'))\n        shutil.copy(join(amos_base_dir, 'labelsVa', vl + '.nii.gz'), join(labelstr, f'{vl}.nii.gz'))\n\n    
generate_dataset_json(out_base, {0: \"either_CT_or_MR\"}, labels={v: int(k) for k,v in dataset_json_source['labels'].items()},\n                          num_training_cases=len(training_identifiers) + len(val_identifiers), file_ending='.nii.gz',\n                          dataset_name=task_name, reference='https://amos22.grand-challenge.org/',\n                          release='https://zenodo.org/record/7262581',\n                          overwrite_image_reader_writer='NibabelIOWithReorient',\n                          description=\"This is the dataset as released AFTER the challenge event. It has the \"\n                                      \"validation set gt in it! We just use the validation images as additional \"\n                                      \"training cases because AMOS doesn't specify how they should be used. nnU-Net's\"\n                                      \" 5-fold CV is better than some random train:val split.\")\n\n\nif __name__ == '__main__':\n    import argparse\n    parser = argparse.ArgumentParser()\n    parser.add_argument('input_folder', type=str,\n                        help=\"The downloaded and extracted AMOS2022 (https://amos22.grand-challenge.org/) data. \"\n                             \"Use this link: https://zenodo.org/record/7262581. \"\n                             \"You need to specify the folder with the imagesTr, imagesVa, labelsTr etc subfolders here!\")\n    parser.add_argument('-d', required=False, type=int, default=219, help='nnU-Net Dataset ID, default: 219')\n    args = parser.parse_args()\n    amos_base = args.input_folder\n    convert_amos_task2(amos_base, args.d)\n\n    # /home/isensee/Downloads/amos22/amos22/\n\n"
  },
  {
    "path": "nnunetv2/dataset_conversion/Dataset220_KiTS2023.py",
    "content": "from batchgenerators.utilities.file_and_folder_operations import *\nimport shutil\nfrom nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json\nfrom nnunetv2.paths import nnUNet_raw\n\n\ndef convert_kits2023(kits_base_dir: str, nnunet_dataset_id: int = 220):\n    task_name = \"KiTS2023\"\n\n    foldername = \"Dataset%03.0d_%s\" % (nnunet_dataset_id, task_name)\n\n    # setting up nnU-Net folders\n    out_base = join(nnUNet_raw, foldername)\n    imagestr = join(out_base, \"imagesTr\")\n    labelstr = join(out_base, \"labelsTr\")\n    maybe_mkdir_p(imagestr)\n    maybe_mkdir_p(labelstr)\n\n    cases = subdirs(kits_base_dir, prefix='case_', join=False)\n    for tr in cases:\n        shutil.copy(join(kits_base_dir, tr, 'imaging.nii.gz'), join(imagestr, f'{tr}_0000.nii.gz'))\n        shutil.copy(join(kits_base_dir, tr, 'segmentation.nii.gz'), join(labelstr, f'{tr}.nii.gz'))\n\n    generate_dataset_json(out_base, {0: \"CT\"},\n                          labels={\n                              \"background\": 0,\n                              \"kidney\": (1, 2, 3),\n                              \"masses\": (2, 3),\n                              \"tumor\": 2\n                          },\n                          regions_class_order=(1, 3, 2),\n                          num_training_cases=len(cases), file_ending='.nii.gz',\n                          dataset_name=task_name, reference='none',\n                          release='0.1.3',\n                          overwrite_image_reader_writer='NibabelIOWithReorient',\n                          description=\"KiTS2023\")\n\n\nif __name__ == '__main__':\n    import argparse\n    parser = argparse.ArgumentParser()\n    parser.add_argument('input_folder', type=str,\n                        help=\"The downloaded and extracted KiTS2023 dataset (must have case_XXXXX subfolders)\")\n    parser.add_argument('-d', required=False, type=int, default=220, help='nnU-Net Dataset ID, default: 220')\n   
 args = parser.parse_args()\n    amos_base = args.input_folder\n    convert_kits2023(amos_base, args.d)\n\n    # /media/isensee/raw_data/raw_datasets/kits23/dataset\n\n"
  },
  {
    "path": "nnunetv2/dataset_conversion/Dataset221_AutoPETII_2023.py",
    "content": "from batchgenerators.utilities.file_and_folder_operations import *\nimport shutil\nfrom nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json\nfrom nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed\n\n\ndef convert_autopet(autopet_base_dir:str = '/media/isensee/My Book1/AutoPET/nifti/FDG-PET-CT-Lesions',\n                     nnunet_dataset_id: int = 221):\n    task_name = \"AutoPETII_2023\"\n\n    foldername = \"Dataset%03.0d_%s\" % (nnunet_dataset_id, task_name)\n\n    # setting up nnU-Net folders\n    out_base = join(nnUNet_raw, foldername)\n    imagestr = join(out_base, \"imagesTr\")\n    labelstr = join(out_base, \"labelsTr\")\n    maybe_mkdir_p(imagestr)\n    maybe_mkdir_p(labelstr)\n\n    patients = subdirs(autopet_base_dir, prefix='PETCT', join=False)\n    n = 0\n    identifiers = []\n    for pat in patients:\n        patient_acquisitions = subdirs(join(autopet_base_dir, pat), join=False)\n        for pa in patient_acquisitions:\n            n += 1\n            identifier = f\"{pat}_{pa}\"\n            identifiers.append(identifier)\n            if not isfile(join(imagestr, f'{identifier}_0000.nii.gz')):\n                shutil.copy(join(autopet_base_dir, pat, pa, 'CTres.nii.gz'), join(imagestr, f'{identifier}_0000.nii.gz'))\n            if not isfile(join(imagestr, f'{identifier}_0001.nii.gz')):\n                shutil.copy(join(autopet_base_dir, pat, pa, 'SUV.nii.gz'), join(imagestr, f'{identifier}_0001.nii.gz'))\n            # the segmentation is written to labelstr, so that is where we must check for an existing copy\n            if not isfile(join(labelstr, f'{identifier}.nii.gz')):\n                shutil.copy(join(autopet_base_dir, pat, pa, 'SEG.nii.gz'), join(labelstr, f'{identifier}.nii.gz'))\n\n    generate_dataset_json(out_base, {0: \"CT\", 1:\"CT\"},\n                          labels={\n                              \"background\": 0,\n                              \"tumor\": 1\n                          },\n                          num_training_cases=n, file_ending='.nii.gz',\n                          
dataset_name=task_name, reference='https://autopet-ii.grand-challenge.org/',\n                          release='release',\n                          # overwrite_image_reader_writer='NibabelIOWithReorient',\n                          description=task_name)\n\n    # manual split\n    splits = []\n    for fold in range(5):\n        val_patients = patients[fold :: 5]\n        splits.append(\n            {\n                'train': [i for i in identifiers if not any([i.startswith(v) for v in val_patients])],\n                'val': [i for i in identifiers if any([i.startswith(v) for v in val_patients])],\n            }\n        )\n    pp_out_dir = join(nnUNet_preprocessed, foldername)\n    maybe_mkdir_p(pp_out_dir)\n    save_json(splits, join(pp_out_dir, 'splits_final.json'), sort_keys=False)\n\n\nif __name__ == '__main__':\n    import argparse\n    parser = argparse.ArgumentParser()\n    parser.add_argument('input_folder', type=str,\n                        help=\"The downloaded and extracted autopet dataset (must have PETCT_XXX subfolders)\")\n    parser.add_argument('-d', required=False, type=int, default=221, help='nnU-Net Dataset ID, default: 221')\n    args = parser.parse_args()\n    autopet_base = args.input_folder\n    convert_autopet(autopet_base, args.d)\n"
  },
  {
    "path": "nnunetv2/dataset_conversion/Dataset223_AMOS2022postChallenge.py",
    "content": "import shutil\n\nfrom batchgenerators.utilities.file_and_folder_operations import *\nfrom nnunetv2.paths import nnUNet_raw\nfrom nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json\n\nif __name__ == '__main__':\n    downloaded_amos_dir = '/home/isensee/amos22/amos22' # downloaded and extracted from https://zenodo.org/record/7155725#.Y0OOCOxBztM\n\n    target_dataset_id = 223\n    target_dataset_name = f'Dataset{target_dataset_id:3.0f}_AMOS2022postChallenge'\n\n    maybe_mkdir_p(join(nnUNet_raw, target_dataset_name))\n    imagesTr = join(nnUNet_raw, target_dataset_name, 'imagesTr')\n    imagesTs = join(nnUNet_raw, target_dataset_name, 'imagesTs')\n    labelsTr = join(nnUNet_raw, target_dataset_name, 'labelsTr')\n    maybe_mkdir_p(imagesTr)\n    maybe_mkdir_p(imagesTs)\n    maybe_mkdir_p(labelsTr)\n\n    train_identifiers = []\n    # copy images\n    source = join(downloaded_amos_dir, 'imagesTr')\n    source_files = nifti_files(source, join=False)\n    train_identifiers += source_files\n    for s in source_files:\n        shutil.copy(join(source, s), join(imagesTr, s[:-7] + '_0000.nii.gz'))\n\n    source = join(downloaded_amos_dir, 'imagesVa')\n    source_files = nifti_files(source, join=False)\n    train_identifiers += source_files\n    for s in source_files:\n        shutil.copy(join(source, s), join(imagesTr, s[:-7] + '_0000.nii.gz'))\n\n    source = join(downloaded_amos_dir, 'imagesTs')\n    source_files = nifti_files(source, join=False)\n    for s in source_files:\n        shutil.copy(join(source, s), join(imagesTs, s[:-7] + '_0000.nii.gz'))\n\n    # copy labels\n    source = join(downloaded_amos_dir, 'labelsTr')\n    source_files = nifti_files(source, join=False)\n    for s in source_files:\n        shutil.copy(join(source, s), join(labelsTr, s))\n\n    source = join(downloaded_amos_dir, 'labelsVa')\n    source_files = nifti_files(source, join=False)\n    for s in source_files:\n        shutil.copy(join(source, s), 
join(labelsTr, s))\n\n    old_dataset_json = load_json(join(downloaded_amos_dir, 'dataset.json'))\n    new_labels = {v: k for k, v in old_dataset_json['labels'].items()}\n\n    generate_dataset_json(join(nnUNet_raw, target_dataset_name), {0: 'nonCT'}, new_labels,\n                          num_training_cases=len(train_identifiers), file_ending='.nii.gz', regions_class_order=None,\n                          dataset_name=target_dataset_name, reference='https://zenodo.org/record/7155725#.Y0OOCOxBztM',\n                          license=old_dataset_json['licence'],  # typo in OG dataset.json\n                          description=old_dataset_json['description'],\n                          release=old_dataset_json['release'])\n"
  },
  {
    "path": "nnunetv2/dataset_conversion/Dataset224_AbdomenAtlas1.0.py",
    "content": "from batchgenerators.utilities.file_and_folder_operations import *\nimport shutil\nfrom nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json\n\n\nif __name__ == '__main__':\n    \"\"\"\n    How to train our submission to the JHU benchmark\n    \n    1. Execute this script here to convert the dataset into nnU-Net format. Adapt the paths to your system!\n    2. Run planning and preprocessing: `nnUNetv2_plan_and_preprocess -d 224 -npfp 64 -np 64 -c 3d_fullres -pl \n    nnUNetPlannerResEncL_torchres`. Adapt the number of processes to your System (-np; -npfp)! Note that each process \n    will again spawn 4 threads for resampling. This custom planner replaces the nnU-Net default resampling scheme with \n    a torch-based implementation which is faster but less accurate. This is needed to satisfy the inference speed \n    constraints.\n    3. Run training with `nnUNetv2_train 224 3d_fullres all -p nnUNetResEncUNetLPlans_torchres`. 24GB VRAM required, \n    training will take ~28-30h.\n    \"\"\"\n\n\n    base = '/home/isensee/Downloads/AbdomenAtlas1.0Mini'\n    cases = subdirs(base, join=False, prefix='BDMAP')\n\n    target_dataset_id = 224\n    target_dataset_name = f'Dataset{target_dataset_id:3.0f}_AbdomenAtlas1.0'\n\n    raw_dir = '/home/isensee/drives/E132-Projekte/Projects/Helmholtz_Imaging_ACVL/2024_JHU_benchmark'\n    maybe_mkdir_p(join(raw_dir, target_dataset_name))\n    imagesTr = join(raw_dir, target_dataset_name, 'imagesTr')\n    labelsTr = join(raw_dir, target_dataset_name, 'labelsTr')\n    maybe_mkdir_p(imagesTr)\n    maybe_mkdir_p(labelsTr)\n\n    for case in cases:\n        shutil.copy(join(base, case, 'ct.nii.gz'), join(imagesTr, case + '_0000.nii.gz'))\n        shutil.copy(join(base, case, 'combined_labels.nii.gz'), join(labelsTr, case + '.nii.gz'))\n\n    labels = {\n        \"background\": 0,\n        \"aorta\": 1,\n        \"gall_bladder\": 2,\n        \"kidney_left\": 3,\n        \"kidney_right\": 4,\n        
\"liver\": 5,\n        \"pancreas\": 6,\n        \"postcava\": 7,\n        \"spleen\": 8,\n        \"stomach\": 9\n    }\n\n    generate_dataset_json(\n        join(raw_dir, target_dataset_name),\n        {0: 'nonCT'},  # this was a mistake we did at the beginning and we keep it like that here for consistency\n        labels,\n        len(cases),\n        '.nii.gz',\n        None,\n        target_dataset_name,\n        overwrite_image_reader_writer='NibabelIOWithReorient'\n    )"
  },
  {
    "path": "nnunetv2/dataset_conversion/Dataset226_BraTS2024-BraTS-GLI.py",
    "content": "from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json\nfrom batchgenerators.utilities.file_and_folder_operations import join, subdirs, subfiles, maybe_mkdir_p\nfrom nnunetv2.paths import nnUNet_raw\n\nif __name__ == '__main__':\n    \"\"\"\n    this dataset does not copy the data into nnunet format and just links to existing data. The dataset can only be \n    used from one machine because the paths in the dataset.json are hard coded\n    \"\"\"\n    extracted_BraTS2024_GLI_dir = '/home/isensee/BraTS2024_traindata/training_data1'\n    nnunet_dataset_name = 'BraTS2024-BraTS-GLI'\n    nnunet_dataset_id = 226\n    dataset_name = f'Dataset{nnunet_dataset_id:03d}_{nnunet_dataset_name}'\n    dataset_dir = join(nnUNet_raw, dataset_name)\n    maybe_mkdir_p(dataset_dir)\n\n    dataset = {}\n    casenames = subdirs(extracted_BraTS2024_GLI_dir, join=False)\n    for c in casenames:\n        dataset[c] = {\n            'label': join(extracted_BraTS2024_GLI_dir, c, c + '-seg.nii.gz'),\n            'images': [\n                join(extracted_BraTS2024_GLI_dir, c, c + '-t1n.nii.gz'),\n                join(extracted_BraTS2024_GLI_dir, c, c + '-t1c.nii.gz'),\n                join(extracted_BraTS2024_GLI_dir, c, c + '-t2w.nii.gz'),\n                join(extracted_BraTS2024_GLI_dir, c, c + '-t2f.nii.gz')\n            ]\n        }\n    labels = {\n        'background': 0,\n        'NETC': 1,\n        'SNFH': 2,\n        'ET': 3,\n        'RC': 4,\n    }\n\n    generate_dataset_json(\n        dataset_dir,\n        {\n            0: 'T1',\n            1: \"T1C\",\n            2: \"T2W\",\n            3: \"T2F\"\n        },\n        labels,\n        num_training_cases=len(dataset),\n        file_ending='.nii.gz',\n        regions_class_order=None,\n        dataset_name=dataset_name,\n        reference='https://www.synapse.org/Synapse:syn53708249/wiki/627500',\n        license='see https://www.synapse.org/Synapse:syn53708249/wiki/627508',\n     
   dataset=dataset,\n        description='This dataset does not copy the data into nnunet format and just links to existing data. '\n                    'The dataset can only be used from one machine because the paths in the dataset.json are hard coded'\n    )\n"
  },
  {
    "path": "nnunetv2/dataset_conversion/Dataset227_TotalSegmentatorMRI.py",
    "content": "import SimpleITK\nimport nibabel\nimport numpy as np\nfrom batchgenerators.utilities.file_and_folder_operations import *\nimport shutil\nfrom nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json\nfrom nnunetv2.paths import nnUNet_raw\n\n\n\nif __name__ == '__main__':\n    base = '/home/isensee/Downloads/TotalsegmentatorMRI_dataset_v100'\n    cases = subdirs(base, join=False)\n\n    target_dataset_id = 227\n    target_dataset_name = f'Dataset{target_dataset_id:3.0f}_TotalSegmentatorMRI'\n\n    maybe_mkdir_p(join(nnUNet_raw, target_dataset_name))\n    imagesTr = join(nnUNet_raw, target_dataset_name, 'imagesTr')\n    labelsTr = join(nnUNet_raw, target_dataset_name, 'labelsTr')\n    maybe_mkdir_p(imagesTr)\n    maybe_mkdir_p(labelsTr)\n\n    # discover labels\n    label_fnames = nifti_files(join(base, cases[0], 'segmentations'), join=False)\n    label_dict = {i[:-7]: j + 1 for j, i in enumerate(label_fnames)}\n    labelnames = list(label_dict.keys())\n    label_dict['background'] = 0\n\n    for case in cases:\n        img = nibabel.load(join(base, case, 'mri.nii.gz'))\n        nibabel.save(img, join(imagesTr, case + '_0000.nii.gz'))\n\n        seg_nib = nibabel.load(join(base, case, 'segmentations', labelnames[0] + '.nii.gz'))\n        init_seg_npy = np.asanyarray(seg_nib.dataobj)\n        init_seg_npy[init_seg_npy > 0] = label_dict[labelnames[0]]\n        for labelname in labelnames[1:]:\n            seg = nibabel.load(join(base, case, 'segmentations', labelname + '.nii.gz'))\n            seg = np.asanyarray(seg.dataobj)\n            init_seg_npy[seg > 0] = label_dict[labelname]\n        out = nibabel.Nifti1Image(init_seg_npy, affine=seg_nib.affine, header=seg_nib.header)\n        nibabel.save(out, join(labelsTr, case + '.nii.gz'))\n\n    generate_dataset_json(\n        join(nnUNet_raw, target_dataset_name),\n        {0: 'MRI'},  # this was a mistake we did at the beginning and we keep it like that here for consistency\n       
 label_dict,\n        len(cases),\n        '.nii.gz',\n        None,\n        target_dataset_name,\n        overwrite_image_reader_writer='NibabelIOWithReorient',\n        release='1.0.0',\n        reference='https://zenodo.org/records/11367005',\n        license='see reference'\n    )"
  },
  {
    "path": "nnunetv2/dataset_conversion/Dataset987_dummyDataset4.py",
    "content": "import os\n\nfrom batchgenerators.utilities.file_and_folder_operations import *\n\nfrom nnunetv2.paths import nnUNet_raw\nfrom nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets\n\nif __name__ == '__main__':\n    # creates a dummy dataset where there are no files in imagestr and labelstr\n    source_dataset = 'Dataset004_Hippocampus'\n\n    target_dataset = 'Dataset987_dummyDataset4'\n    target_dataset_dir = join(nnUNet_raw, target_dataset)\n    maybe_mkdir_p(target_dataset_dir)\n\n    dataset = get_filenames_of_train_images_and_targets(join(nnUNet_raw, source_dataset))\n\n    # the returned dataset will have absolute paths. We should use relative paths so that you can freely copy\n    # datasets around between systems. As long as the source dataset is there it will continue working even if\n    # nnUNet_raw is in different locations\n\n    # paths must be relative to target_dataset_dir!!!\n    for k in dataset.keys():\n        dataset[k]['label'] = os.path.relpath(dataset[k]['label'], target_dataset_dir)\n        dataset[k]['images'] = [os.path.relpath(i, target_dataset_dir) for i in dataset[k]['images']]\n\n    # load old dataset.json\n    dataset_json = load_json(join(nnUNet_raw, source_dataset, 'dataset.json'))\n    dataset_json['dataset'] = dataset\n\n    # save\n    save_json(dataset_json, join(target_dataset_dir, 'dataset.json'), sort_keys=False)\n"
  },
  {
    "path": "nnunetv2/dataset_conversion/Dataset989_dummyDataset4_2.py",
    "content": "import os\n\nfrom batchgenerators.utilities.file_and_folder_operations import *\n\nfrom nnunetv2.paths import nnUNet_raw\nfrom nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets\n\nif __name__ == '__main__':\n    # creates a dummy dataset where there are no files in imagestr and labelstr\n    source_dataset = 'Dataset004_Hippocampus'\n\n    target_dataset = 'Dataset989_dummyDataset4_2'\n    target_dataset_dir = join(nnUNet_raw, target_dataset)\n    maybe_mkdir_p(target_dataset_dir)\n\n    dataset = get_filenames_of_train_images_and_targets(join(nnUNet_raw, source_dataset))\n\n    # the returned dataset will have absolute paths. We should use relative paths so that you can freely copy\n    # datasets around between systems. As long as the source dataset is there it will continue working even if\n    # nnUNet_raw is in different locations\n\n    # in this variant, paths are written relative to nnUNet_raw and prefixed with the literal string '$nnUNet_raw'\n    for k in dataset.keys():\n        dataset[k]['label'] = join('$nnUNet_raw', os.path.relpath(dataset[k]['label'], nnUNet_raw))\n        dataset[k]['images'] = [join('$nnUNet_raw', os.path.relpath(i, nnUNet_raw)) for i in dataset[k]['images']]\n\n    # load old dataset.json\n    dataset_json = load_json(join(nnUNet_raw, source_dataset, 'dataset.json'))\n    dataset_json['dataset'] = dataset\n\n    # save\n    save_json(dataset_json, join(target_dataset_dir, 'dataset.json'), sort_keys=False)\n"
  },
  {
    "path": "nnunetv2/dataset_conversion/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/dataset_conversion/convert_MSD_dataset.py",
    "content": "import argparse\nimport multiprocessing\nimport shutil\nfrom typing import Optional\nimport SimpleITK as sitk\nfrom batchgenerators.utilities.file_and_folder_operations import *\nfrom nnunetv2.paths import nnUNet_raw\nfrom nnunetv2.utilities.dataset_name_id_conversion import find_candidate_datasets\nfrom nnunetv2.configuration import default_num_processes\nimport numpy as np\n\n\ndef split_4d_nifti(filename, output_folder):\n    img_itk = sitk.ReadImage(filename)\n    dim = img_itk.GetDimension()\n    file_base = os.path.basename(filename)\n    if dim == 3:\n        shutil.copy(filename, join(output_folder, file_base[:-7] + \"_0000.nii.gz\"))\n        return\n    elif dim != 4:\n        raise RuntimeError(\"Unexpected dimensionality: %d of file %s, cannot split\" % (dim, filename))\n    else:\n        img_npy = sitk.GetArrayFromImage(img_itk)\n        spacing = img_itk.GetSpacing()\n        origin = img_itk.GetOrigin()\n        direction = np.array(img_itk.GetDirection()).reshape(4,4)\n        # now modify these to remove the fourth dimension\n        spacing = tuple(list(spacing[:-1]))\n        origin = tuple(list(origin[:-1]))\n        direction = tuple(direction[:-1, :-1].reshape(-1))\n        for i, t in enumerate(range(img_npy.shape[0])):\n            img = img_npy[t]\n            img_itk_new = sitk.GetImageFromArray(img)\n            img_itk_new.SetSpacing(spacing)\n            img_itk_new.SetOrigin(origin)\n            img_itk_new.SetDirection(direction)\n            sitk.WriteImage(img_itk_new, join(output_folder, file_base[:-7] + \"_%04.0d.nii.gz\" % i))\n\n\ndef convert_msd_dataset(source_folder: str, overwrite_target_id: Optional[int] = None,\n                        num_processes: int = default_num_processes) -> None:\n    if source_folder.endswith('/') or source_folder.endswith('\\\\'):\n        source_folder = source_folder[:-1]\n\n    labelsTr = join(source_folder, 'labelsTr')\n    imagesTs = join(source_folder, 'imagesTs')\n    
imagesTr = join(source_folder, 'imagesTr')\n    assert isdir(labelsTr), f\"labelsTr subfolder missing in source folder\"\n    assert isdir(imagesTs), f\"imagesTs subfolder missing in source folder\"\n    assert isdir(imagesTr), f\"imagesTr subfolder missing in source folder\"\n    dataset_json = join(source_folder, 'dataset.json')\n    assert isfile(dataset_json), f\"dataset.json missing in source_folder\"\n\n    # infer source dataset id and name\n    task, dataset_name = os.path.basename(source_folder).split('_')\n    task_id = int(task[4:])\n\n    # check if target dataset id is taken\n    target_id = task_id if overwrite_target_id is None else overwrite_target_id\n    existing_datasets = find_candidate_datasets(target_id)\n    assert len(existing_datasets) == 0, f\"Target dataset id {target_id} is already taken, please consider changing \" \\\n                                        f\"it using overwrite_target_id. Conflicting dataset: {existing_datasets} (check nnUNet_results, nnUNet_preprocessed and nnUNet_raw!)\"\n\n    target_dataset_name = f\"Dataset{target_id:03d}_{dataset_name}\"\n    target_folder = join(nnUNet_raw, target_dataset_name)\n    target_imagesTr = join(target_folder, 'imagesTr')\n    target_imagesTs = join(target_folder, 'imagesTs')\n    target_labelsTr = join(target_folder, 'labelsTr')\n    maybe_mkdir_p(target_imagesTr)\n    maybe_mkdir_p(target_imagesTs)\n    maybe_mkdir_p(target_labelsTr)\n\n    with multiprocessing.get_context(\"spawn\").Pool(num_processes) as p:\n        results = []\n\n        # convert 4d train images\n        source_images = [i for i in subfiles(imagesTr, suffix='.nii.gz', join=False) if\n                         not i.startswith('.') and not i.startswith('_')]\n        source_images = [join(imagesTr, i) for i in source_images]\n\n        results.append(\n            p.starmap_async(\n                split_4d_nifti, zip(source_images, [target_imagesTr] * len(source_images))\n            )\n        )\n\n        # 
convert 4d test images\n        source_images = [i for i in subfiles(imagesTs, suffix='.nii.gz', join=False) if\n                         not i.startswith('.') and not i.startswith('_')]\n        source_images = [join(imagesTs, i) for i in source_images]\n\n        results.append(\n            p.starmap_async(\n                split_4d_nifti, zip(source_images, [target_imagesTs] * len(source_images))\n            )\n        )\n\n        # copy segmentations\n        source_images = [i for i in subfiles(labelsTr, suffix='.nii.gz', join=False) if\n                         not i.startswith('.') and not i.startswith('_')]\n        for s in source_images:\n            shutil.copy(join(labelsTr, s), join(target_labelsTr, s))\n\n        [i.get() for i in results]\n\n    dataset_json = load_json(dataset_json)\n    dataset_json['labels'] = {j: int(i) for i, j in dataset_json['labels'].items()}\n    dataset_json['file_ending'] = \".nii.gz\"\n    dataset_json[\"channel_names\"] = dataset_json[\"modality\"]\n    del dataset_json[\"modality\"]\n    del dataset_json[\"training\"]\n    del dataset_json[\"test\"]\n    save_json(dataset_json, join(nnUNet_raw, target_dataset_name, 'dataset.json'), sort_keys=False)\n\n\ndef entry_point():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('-i', type=str, required=True,\n                        help='Downloaded and extracted MSD dataset folder. CANNOT be nnUNetv1 dataset! Example: '\n                             '/home/fabian/Downloads/Task05_Prostate')\n    parser.add_argument('-overwrite_id', type=int, required=False, default=None,\n                        help='Overwrite the dataset id. If not set we use the id of the MSD task (inferred from '\n                             'folder name). Only use this if you already have an equivalently numbered dataset!')\n    parser.add_argument('-np', type=int, required=False, default=default_num_processes,\n                        help=f'Number of processes used. 
Default: {default_num_processes}')\n    args = parser.parse_args()\n    convert_msd_dataset(args.i, args.overwrite_id, args.np)\n\n\nif __name__ == '__main__':\n    convert_msd_dataset('/home/fabian/Downloads/Task05_Prostate', overwrite_target_id=201)\n"
  },
  {
    "path": "nnunetv2/dataset_conversion/convert_raw_dataset_from_old_nnunet_format.py",
    "content": "import shutil\nfrom copy import deepcopy\n\nfrom batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p, isdir, load_json, save_json\nfrom nnunetv2.paths import nnUNet_raw\n\n\ndef convert(source_folder, target_dataset_name):\n    \"\"\"\n    remember that old tasks were called TaskXXX_YYY and new ones are called DatasetXXX_YYY\n    source_folder\n    \"\"\"\n    if isdir(join(nnUNet_raw, target_dataset_name)):\n        raise RuntimeError(f'Target dataset name {target_dataset_name} already exists. Aborting... '\n                           f'(we might break something). If you are sure you want to proceed, please manually '\n                           f'delete {join(nnUNet_raw, target_dataset_name)}')\n    maybe_mkdir_p(join(nnUNet_raw, target_dataset_name))\n    shutil.copytree(join(source_folder, 'imagesTr'), join(nnUNet_raw, target_dataset_name, 'imagesTr'))\n    shutil.copytree(join(source_folder, 'labelsTr'), join(nnUNet_raw, target_dataset_name, 'labelsTr'))\n    if isdir(join(source_folder, 'imagesTs')):\n        shutil.copytree(join(source_folder, 'imagesTs'), join(nnUNet_raw, target_dataset_name, 'imagesTs'))\n    if isdir(join(source_folder, 'labelsTs')):\n        shutil.copytree(join(source_folder, 'labelsTs'), join(nnUNet_raw, target_dataset_name, 'labelsTs'))\n    if isdir(join(source_folder, 'imagesVal')):\n        shutil.copytree(join(source_folder, 'imagesVal'), join(nnUNet_raw, target_dataset_name, 'imagesVal'))\n    if isdir(join(source_folder, 'labelsVal')):\n        shutil.copytree(join(source_folder, 'labelsVal'), join(nnUNet_raw, target_dataset_name, 'labelsVal'))\n    shutil.copy(join(source_folder, 'dataset.json'), join(nnUNet_raw, target_dataset_name))\n\n    dataset_json = load_json(join(nnUNet_raw, target_dataset_name, 'dataset.json'))\n    del dataset_json['tensorImageSize']\n    del dataset_json['numTest']\n    del dataset_json['training']\n    del dataset_json['test']\n    
dataset_json['channel_names'] = deepcopy(dataset_json['modality'])\n    del dataset_json['modality']\n\n    dataset_json['labels'] = {j: int(i) for i, j in dataset_json['labels'].items()}\n    dataset_json['file_ending'] = \".nii.gz\"\n    save_json(dataset_json, join(nnUNet_raw, target_dataset_name, 'dataset.json'), sort_keys=False)\n\n\ndef convert_entry_point():\n    import argparse\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"input_folder\", type=str,\n                        help='Raw old nnUNet dataset. This must be the folder with imagesTr,labelsTr etc subfolders! '\n                             'Please provide the PATH to the old Task, not just the task name. nnU-Net V2 does not '\n                             'know where v1 tasks are.')\n    parser.add_argument(\"output_dataset_name\", type=str,\n                        help='New dataset NAME (not path!). Must follow the DatasetXXX_NAME convention!')\n    args = parser.parse_args()\n    convert(args.input_folder, args.output_dataset_name)\n"
  },
  {
    "path": "nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset996_IntegrationTest_Hippocampus_regions_ignore.py",
    "content": "import SimpleITK as sitk\nimport shutil\n\nimport numpy as np\nfrom batchgenerators.utilities.file_and_folder_operations import isdir, join, load_json, save_json, nifti_files\n\nfrom nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name\nfrom nnunetv2.paths import nnUNet_raw\nfrom nnunetv2.utilities.label_handling.label_handling import LabelManager\nfrom nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager\n\n\ndef sparsify_segmentation(seg: np.ndarray, label_manager: LabelManager, percent_of_slices: float) -> np.ndarray:\n        assert label_manager.has_ignore_label, \"This preprocessor only works with datasets that have an ignore label!\"\n        seg_new = np.ones_like(seg) * label_manager.ignore_label\n        x, y, z = seg.shape\n        # x\n        num_slices = max(1, round(x * percent_of_slices))\n        selected_slices = np.random.choice(x, num_slices, replace=False)\n        seg_new[selected_slices] = seg[selected_slices]\n        # y\n        num_slices = max(1, round(y * percent_of_slices))\n        selected_slices = np.random.choice(y, num_slices, replace=False)\n        seg_new[:, selected_slices] = seg[:, selected_slices]\n        # z\n        num_slices = max(1, round(z * percent_of_slices))\n        selected_slices = np.random.choice(z, num_slices, replace=False)\n        seg_new[:, :, selected_slices] = seg[:, :, selected_slices]\n        return seg_new\n\n\nif __name__ == '__main__':\n    dataset_name = 'IntegrationTest_Hippocampus_regions_ignore'\n    dataset_id = 996\n    dataset_name = f\"Dataset{dataset_id:03d}_{dataset_name}\"\n\n    try:\n        existing_dataset_name = maybe_convert_to_dataset_name(dataset_id)\n        if existing_dataset_name != dataset_name:\n            raise FileExistsError(f\"A different dataset with id {dataset_id} already exists :-(: {existing_dataset_name}. 
If \"\n                               f\"you intent to delete it, remember to also remove it in nnUNet_preprocessed and \"\n                               f\"nnUNet_results!\")\n    except RuntimeError:\n        pass\n\n    if isdir(join(nnUNet_raw, dataset_name)):\n        shutil.rmtree(join(nnUNet_raw, dataset_name))\n\n    source_dataset = maybe_convert_to_dataset_name(4)\n    shutil.copytree(join(nnUNet_raw, source_dataset), join(nnUNet_raw, dataset_name))\n\n    # additionally optimize entire hippocampus region, remove Posterior\n    dj = load_json(join(nnUNet_raw, dataset_name, 'dataset.json'))\n    dj['labels'] = {\n        'background': 0,\n        'hippocampus': (1, 2),\n        'anterior': 1,\n        'ignore': 3\n    }\n    dj['regions_class_order'] = (2, 1)\n    save_json(dj, join(nnUNet_raw, dataset_name, 'dataset.json'), sort_keys=False)\n\n    # now add ignore label to segmentation images\n    np.random.seed(1234)\n    lm = LabelManager(label_dict=dj['labels'], regions_class_order=dj.get('regions_class_order'))\n\n    segs = nifti_files(join(nnUNet_raw, dataset_name, 'labelsTr'))\n    for s in segs:\n        seg_itk = sitk.ReadImage(s)\n        seg_npy = sitk.GetArrayFromImage(seg_itk)\n        seg_npy = sparsify_segmentation(seg_npy, lm, 0.1 / 3)\n        seg_itk_new = sitk.GetImageFromArray(seg_npy)\n        seg_itk_new.CopyInformation(seg_itk)\n        sitk.WriteImage(seg_itk_new, s)\n\n"
  },
  {
    "path": "nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset997_IntegrationTest_Hippocampus_regions.py",
    "content": "import shutil\n\nfrom batchgenerators.utilities.file_and_folder_operations import isdir, join, load_json, save_json\n\nfrom nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name\nfrom nnunetv2.paths import nnUNet_raw\n\nif __name__ == '__main__':\n    dataset_name = 'IntegrationTest_Hippocampus_regions'\n    dataset_id = 997\n    dataset_name = f\"Dataset{dataset_id:03d}_{dataset_name}\"\n\n    try:\n        existing_dataset_name = maybe_convert_to_dataset_name(dataset_id)\n        if existing_dataset_name != dataset_name:\n            raise FileExistsError(\n                f\"A different dataset with id {dataset_id} already exists :-(: {existing_dataset_name}. If \"\n                f\"you intent to delete it, remember to also remove it in nnUNet_preprocessed and \"\n                f\"nnUNet_results!\")\n    except RuntimeError:\n        pass\n\n    if isdir(join(nnUNet_raw, dataset_name)):\n        shutil.rmtree(join(nnUNet_raw, dataset_name))\n\n    source_dataset = maybe_convert_to_dataset_name(4)\n    shutil.copytree(join(nnUNet_raw, source_dataset), join(nnUNet_raw, dataset_name))\n\n    # additionally optimize entire hippocampus region, remove Posterior\n    dj = load_json(join(nnUNet_raw, dataset_name, 'dataset.json'))\n    dj['labels'] = {\n        'background': 0,\n        'hippocampus': (1, 2),\n        'anterior': 1\n    }\n    dj['regions_class_order'] = (2, 1)\n    save_json(dj, join(nnUNet_raw, dataset_name, 'dataset.json'), sort_keys=False)\n"
  },
  {
    "path": "nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset998_IntegrationTest_Hippocampus_ignore.py",
    "content": "import shutil\n\nfrom batchgenerators.utilities.file_and_folder_operations import isdir, join, load_json, save_json\n\nfrom nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name\nfrom nnunetv2.paths import nnUNet_raw\n\n\nif __name__ == '__main__':\n    dataset_name = 'IntegrationTest_Hippocampus_ignore'\n    dataset_id = 998\n    dataset_name = f\"Dataset{dataset_id:03d}_{dataset_name}\"\n\n    try:\n        existing_dataset_name = maybe_convert_to_dataset_name(dataset_id)\n        if existing_dataset_name != dataset_name:\n            raise FileExistsError(f\"A different dataset with id {dataset_id} already exists :-(: {existing_dataset_name}. If \"\n                               f\"you intent to delete it, remember to also remove it in nnUNet_preprocessed and \"\n                               f\"nnUNet_results!\")\n    except RuntimeError:\n        pass\n\n    if isdir(join(nnUNet_raw, dataset_name)):\n        shutil.rmtree(join(nnUNet_raw, dataset_name))\n\n    source_dataset = maybe_convert_to_dataset_name(4)\n    shutil.copytree(join(nnUNet_raw, source_dataset), join(nnUNet_raw, dataset_name))\n\n    # set class 2 to ignore label\n    dj = load_json(join(nnUNet_raw, dataset_name, 'dataset.json'))\n    dj['labels']['ignore'] = 2\n    del dj['labels']['Posterior']\n    save_json(dj, join(nnUNet_raw, dataset_name, 'dataset.json'), sort_keys=False)\n"
  },
  {
    "path": "nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset999_IntegrationTest_Hippocampus.py",
    "content": "import shutil\n\nfrom batchgenerators.utilities.file_and_folder_operations import isdir, join\n\nfrom nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name\nfrom nnunetv2.paths import nnUNet_raw\n\n\nif __name__ == '__main__':\n    dataset_name = 'IntegrationTest_Hippocampus'\n    dataset_id = 999\n    dataset_name = f\"Dataset{dataset_id:03d}_{dataset_name}\"\n\n    try:\n        existing_dataset_name = maybe_convert_to_dataset_name(dataset_id)\n        if existing_dataset_name != dataset_name:\n            raise FileExistsError(f\"A different dataset with id {dataset_id} already exists :-(: {existing_dataset_name}. If \"\n                               f\"you intent to delete it, remember to also remove it in nnUNet_preprocessed and \"\n                               f\"nnUNet_results!\")\n    except RuntimeError:\n        pass\n\n    if isdir(join(nnUNet_raw, dataset_name)):\n        shutil.rmtree(join(nnUNet_raw, dataset_name))\n\n    source_dataset = maybe_convert_to_dataset_name(4)\n    shutil.copytree(join(nnUNet_raw, source_dataset), join(nnUNet_raw, dataset_name))\n"
  },
  {
    "path": "nnunetv2/dataset_conversion/datasets_for_integration_tests/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/dataset_conversion/generate_dataset_json.py",
    "content": "from typing import Tuple, Union, List\n\nfrom batchgenerators.utilities.file_and_folder_operations import save_json, join\n\n\ndef generate_dataset_json(output_folder: str,\n                          channel_names: dict,\n                          labels: dict,\n                          num_training_cases: int,\n                          file_ending: str,\n                          citation: Union[List[str], str] = None,\n                          regions_class_order: Tuple[int, ...] = None,\n                          dataset_name: str = None,\n                          reference: str = None,\n                          release: str = None,\n                          description: str = None,\n                          overwrite_image_reader_writer: str = None,\n                          license: str = 'Whoever converted this dataset was lazy and didn\\'t look it up!',\n                          converted_by: str = \"Please enter your name, especially when sharing datasets with others in a common infrastructure!\",\n                          **kwargs):\n    \"\"\"\n    Generates a dataset.json file in the output folder\n\n    channel_names:\n        Channel names must map the index to the name of the channel, example:\n        {\n            0: 'T1',\n            1: 'CT'\n        }\n        Note that the channel names may influence the normalization scheme!! Learn more in the documentation.\n\n    labels:\n        This will tell nnU-Net what labels to expect. 
Important: This will also determine whether you use region-based training or not.\n        Example regular labels:\n        {\n            'background': 0,\n            'left atrium': 1,\n            'some other label': 2\n        }\n        Example region-based training:\n        {\n            'background': 0,\n            'whole tumor': (1, 2, 3),\n            'tumor core': (2, 3),\n            'enhancing tumor': 3\n        }\n\n        Remember that nnU-Net expects consecutive values for labels! nnU-Net also expects 0 to be background!\n\n    num_training_cases: is used to double check all cases are there!\n\n    file_ending: needed for finding the files correctly. IMPORTANT! File endings must match between images and\n    segmentations!\n\n    dataset_name, reference, release, license, description: self-explanatory and not used by nnU-Net. Just for\n    completeness and as a reminder that these would be great!\n\n    overwrite_image_reader_writer: If you need a special IO class for your dataset you can derive it from\n    BaseReaderWriter, place it into nnunet.imageio and reference it here by name\n\n    kwargs: whatever you put here will be placed in the dataset.json as well\n\n    \"\"\"\n    has_regions: bool = any([isinstance(i, (tuple, list)) and len(i) > 1 for i in labels.values()])\n    if has_regions:\n        assert regions_class_order is not None, f\"You have defined regions but regions_class_order is not set. 
\" \\\n                                                f\"You need that.\"\n    # channel names need strings as keys\n    keys = list(channel_names.keys())\n    for k in keys:\n        if not isinstance(k, str):\n            channel_names[str(k)] = channel_names[k]\n            del channel_names[k]\n\n    # labels need ints as values\n    for l in labels.keys():\n        value = labels[l]\n        if isinstance(value, (tuple, list)):\n            value = tuple([int(i) for i in value])\n            labels[l] = value\n        else:\n            labels[l] = int(labels[l])\n\n    dataset_json = {\n        'channel_names': channel_names,  # previously this was called 'modality'. I didn't like this so this is\n        # channel_names now. Live with it.\n        'labels': labels,\n        'numTraining': num_training_cases,\n        'file_ending': file_ending,\n        'licence': license,\n        'converted_by': converted_by\n    }\n\n    if dataset_name is not None:\n        dataset_json['name'] = dataset_name\n    if reference is not None:\n        dataset_json['reference'] = reference\n    if release is not None:\n        dataset_json['release'] = release\n    if citation is not None:\n        dataset_json['citation'] = release\n    if description is not None:\n        dataset_json['description'] = description\n    if overwrite_image_reader_writer is not None:\n        dataset_json['overwrite_image_reader_writer'] = overwrite_image_reader_writer\n    if regions_class_order is not None:\n        dataset_json['regions_class_order'] = regions_class_order\n\n    dataset_json.update(kwargs)\n\n    save_json(dataset_json, join(output_folder, 'dataset.json'), sort_keys=False)\n"
  },
  {
    "path": "nnunetv2/ensembling/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/ensembling/ensemble.py",
    "content": "import argparse\nimport multiprocessing\nimport shutil\nfrom copy import deepcopy\nfrom typing import List, Union, Tuple\n\nimport numpy as np\nfrom batchgenerators.utilities.file_and_folder_operations import load_json, join, subfiles, \\\n    maybe_mkdir_p, isdir, save_pickle, load_pickle, isfile\nfrom nnunetv2.configuration import default_num_processes\nfrom nnunetv2.imageio.base_reader_writer import BaseReaderWriter\nfrom nnunetv2.utilities.label_handling.label_handling import LabelManager\nfrom nnunetv2.utilities.plans_handling.plans_handler import PlansManager\n\n\ndef average_probabilities(list_of_files: List[str]) -> np.ndarray:\n    assert len(list_of_files), 'At least one file must be given in list_of_files'\n    avg = None\n    for f in list_of_files:\n        if avg is None:\n            avg = np.load(f)['probabilities']\n            # maybe increase precision to prevent rounding errors\n            if avg.dtype != np.float32:\n                avg = avg.astype(np.float32)\n        else:\n            avg += np.load(f)['probabilities']\n    avg /= len(list_of_files)\n    return avg\n\n\ndef merge_files(list_of_files,\n                output_filename_truncated: str,\n                output_file_ending: str,\n                image_reader_writer: BaseReaderWriter,\n                label_manager: LabelManager,\n                save_probabilities: bool = False):\n    # load the pkl file associated with the first file in list_of_files\n    properties = load_pickle(list_of_files[0][:-4] + '.pkl')\n    # load and average predictions\n    probabilities = average_probabilities(list_of_files)\n    segmentation = label_manager.convert_logits_to_segmentation(probabilities)\n    image_reader_writer.write_seg(segmentation, output_filename_truncated + output_file_ending, properties)\n    if save_probabilities:\n        np.savez_compressed(output_filename_truncated + '.npz', probabilities=probabilities)\n        save_pickle(probabilities, 
output_filename_truncated + '.pkl')\n\n\ndef ensemble_folders(list_of_input_folders: List[str],\n                     output_folder: str,\n                     save_merged_probabilities: bool = False,\n                     num_processes: int = default_num_processes,\n                     dataset_json_file_or_dict: str = None,\n                     plans_json_file_or_dict: str = None):\n    \"\"\"we need too much shit for this function. Problem is that we now have to support region-based training plus\n    multiple input/output formats so there isn't really a way around this.\n\n    If plans and dataset json are not specified, we assume each of the folders has a corresponding plans.json\n    and/or dataset.json in it. These are usually copied into those folders by nnU-Net during prediction.\n    We just pick the dataset.json and plans.json from the first of the folders and we DONT check whether the 5\n    folders contain the same plans etc! This can be a feature if results from different datasets are to be merged (only\n    works if label dict in dataset.json is the same between these datasets!!!)\"\"\"\n    if dataset_json_file_or_dict is not None:\n        if isinstance(dataset_json_file_or_dict, str):\n            dataset_json = load_json(dataset_json_file_or_dict)\n        else:\n            dataset_json = dataset_json_file_or_dict\n    else:\n        dataset_json = load_json(join(list_of_input_folders[0], 'dataset.json'))\n\n    if plans_json_file_or_dict is not None:\n        if isinstance(plans_json_file_or_dict, str):\n            plans = load_json(plans_json_file_or_dict)\n        else:\n            plans = plans_json_file_or_dict\n    else:\n        plans = load_json(join(list_of_input_folders[0], 'plans.json'))\n\n    plans_manager = PlansManager(plans)\n\n    # now collect the files in each of the folders and enforce that all files are present in all folders\n    files_per_folder = [set(subfiles(i, suffix='.npz', join=False)) for i in 
list_of_input_folders]\n    # first build a set with all files\n    s = deepcopy(files_per_folder[0])\n    for f in files_per_folder[1:]:\n        s.update(f)\n    for f in files_per_folder:\n        assert len(s.difference(f)) == 0, \"Not all folders contain the same files for ensembling. Please only \" \\\n                                          \"provide folders that contain the predictions\"\n    lists_of_lists_of_files = [[join(fl, fi) for fl in list_of_input_folders] for fi in s]\n    output_files_truncated = [join(output_folder, fi[:-4]) for fi in s]\n\n    image_reader_writer = plans_manager.image_reader_writer_class()\n    label_manager = plans_manager.get_label_manager(dataset_json)\n\n    maybe_mkdir_p(output_folder)\n    shutil.copy(join(list_of_input_folders[0], 'dataset.json'), output_folder)\n\n    with multiprocessing.get_context(\"spawn\").Pool(num_processes) as pool:\n        num_preds = len(s)\n        _ = pool.starmap(\n            merge_files,\n            zip(\n                lists_of_lists_of_files,\n                output_files_truncated,\n                [dataset_json['file_ending']] * num_preds,\n                [image_reader_writer] * num_preds,\n                [label_manager] * num_preds,\n                [save_merged_probabilities] * num_preds\n            )\n        )\n\n\ndef entry_point_ensemble_folders():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('-i', nargs='+', type=str, required=True,\n                        help='list of input folders')\n    parser.add_argument('-o', type=str, required=True, help='output folder')\n    parser.add_argument('-np', type=int, required=False, default=default_num_processes,\n                        help=f\"Numbers of processes used for ensembling. 
Default: {default_num_processes}\")\n    parser.add_argument('--save_npz', action='store_true', required=False, help='Set this flag to store output '\n                                                                                'probabilities in separate .npz files')\n\n    args = parser.parse_args()\n    ensemble_folders(args.i, args.o, args.save_npz, args.np)\n\n\ndef ensemble_crossvalidations(list_of_trained_model_folders: List[str],\n                              output_folder: str,\n                              folds: Union[Tuple[int, ...], List[int]] = (0, 1, 2, 3, 4),\n                              num_processes: int = default_num_processes,\n                              overwrite: bool = True) -> None:\n    \"\"\"\n    Feature: different configurations can now have different splits\n    \"\"\"\n    dataset_json = load_json(join(list_of_trained_model_folders[0], 'dataset.json'))\n    plans_manager = PlansManager(join(list_of_trained_model_folders[0], 'plans.json'))\n\n    # first collect all unique filenames\n    files_per_folder = {}\n    unique_filenames = set()\n    for tr in list_of_trained_model_folders:\n        files_per_folder[tr] = {}\n        for f in folds:\n            if not isdir(join(tr, f'fold_{f}', 'validation')):\n                raise RuntimeError(f'Expected model output directory does not exist. You must train all requested '\n                                   f'folds of the specified model.\\nModel: {tr}\\nFold: {f}')\n            files_here = subfiles(join(tr, f'fold_{f}', 'validation'), suffix='.npz', join=False)\n            if len(files_here) == 0:\n                raise RuntimeError(f\"No .npz files found in folder {join(tr, f'fold_{f}', 'validation')}. Rerun your \"\n                                   f\"validation with the --npz flag. Use nnUNetv2_train [...] 
--val --npz.\")\n            files_per_folder[tr][f] = subfiles(join(tr, f'fold_{f}', 'validation'), suffix='.npz', join=False)\n            unique_filenames.update(files_per_folder[tr][f])\n\n    # verify that all trained_model_folders have all predictions\n    ok = True\n    for tr, fi in files_per_folder.items():\n        all_files_here = set()\n        for f in folds:\n            all_files_here.update(fi[f])\n        diff = unique_filenames.difference(all_files_here)\n        if len(diff) > 0:\n            ok = False\n            print(f'model {tr} does not seem to contain all predictions. Missing: {diff}')\n        if not ok:\n            raise RuntimeError('There were missing files, see print statements above this one')\n\n    # now we need to collect where these files are\n    file_mapping = []\n    for tr in list_of_trained_model_folders:\n        file_mapping.append({})\n        for f in folds:\n            for fi in files_per_folder[tr][f]:\n                # check for duplicates\n                assert fi not in file_mapping[-1].keys(), f\"Duplicate detected. 
Case {fi} is present in more than \" \\\n                                                          f\"one fold of model {tr}.\"\n                file_mapping[-1][fi] = join(tr, f'fold_{f}', 'validation', fi)\n\n    lists_of_lists_of_files = [[fm[i] for fm in file_mapping] for i in unique_filenames]\n    output_files_truncated = [join(output_folder, fi[:-4]) for fi in unique_filenames]\n\n    image_reader_writer = plans_manager.image_reader_writer_class()\n    maybe_mkdir_p(output_folder)\n    label_manager = plans_manager.get_label_manager(dataset_json)\n\n    if not overwrite:\n        tmp = [isfile(i + dataset_json['file_ending']) for i in output_files_truncated]\n        lists_of_lists_of_files = [lists_of_lists_of_files[i] for i in range(len(tmp)) if not tmp[i]]\n        output_files_truncated = [output_files_truncated[i] for i in range(len(tmp)) if not tmp[i]]\n\n    with multiprocessing.get_context(\"spawn\").Pool(num_processes) as pool:\n        num_preds = len(lists_of_lists_of_files)\n        _ = pool.starmap(\n            merge_files,\n            zip(\n                lists_of_lists_of_files,\n                output_files_truncated,\n                [dataset_json['file_ending']] * num_preds,\n                [image_reader_writer] * num_preds,\n                [label_manager] * num_preds,\n                [False] * num_preds\n            )\n        )\n\n    shutil.copy(join(list_of_trained_model_folders[0], 'plans.json'), join(output_folder, 'plans.json'))\n    shutil.copy(join(list_of_trained_model_folders[0], 'dataset.json'), join(output_folder, 'dataset.json'))\n"
  },
  {
    "path": "nnunetv2/evaluation/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/evaluation/accumulate_cv_results.py",
    "content": "import shutil\nfrom typing import Union, List, Tuple\n\nfrom batchgenerators.utilities.file_and_folder_operations import load_json, join, isdir, maybe_mkdir_p, subfiles, isfile\n\nfrom nnunetv2.configuration import default_num_processes\nfrom nnunetv2.evaluation.evaluate_predictions import compute_metrics_on_folder\nfrom nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed\nfrom nnunetv2.utilities.plans_handling.plans_handler import PlansManager\n\n\ndef accumulate_cv_results(trained_model_folder,\n                          merged_output_folder: str,\n                          folds: Union[List[int], Tuple[int, ...]],\n                          num_processes: int = default_num_processes,\n                          overwrite: bool = True):\n    \"\"\"\n    There are a lot of things that can get fucked up, so the simplest way to deal with potential problems is to\n    collect the cv results into a separate folder and then evaluate them again. No messing with summary_json files!\n    \"\"\"\n\n    if overwrite and isdir(merged_output_folder):\n        shutil.rmtree(merged_output_folder)\n    maybe_mkdir_p(merged_output_folder)\n\n    dataset_json = load_json(join(trained_model_folder, 'dataset.json'))\n    plans_manager = PlansManager(join(trained_model_folder, 'plans.json'))\n    rw = plans_manager.image_reader_writer_class()\n    shutil.copy(join(trained_model_folder, 'dataset.json'), join(merged_output_folder, 'dataset.json'))\n    shutil.copy(join(trained_model_folder, 'plans.json'), join(merged_output_folder, 'plans.json'))\n\n    did_we_copy_something = False\n    for f in folds:\n        expected_validation_folder = join(trained_model_folder, f'fold_{f}', 'validation')\n        if not isdir(expected_validation_folder):\n            raise RuntimeError(f\"fold {f} of model {trained_model_folder} is missing. 
Please train it!\")\n        predicted_files = subfiles(expected_validation_folder, suffix=dataset_json['file_ending'], join=False)\n        for pf in predicted_files:\n            if overwrite and isfile(join(merged_output_folder, pf)):\n                raise RuntimeError(f'More than one of your folds has a prediction for case {pf}')\n            if overwrite or not isfile(join(merged_output_folder, pf)):\n                shutil.copy(join(expected_validation_folder, pf), join(merged_output_folder, pf))\n                did_we_copy_something = True\n\n    if did_we_copy_something or not isfile(join(merged_output_folder, 'summary.json')):\n        label_manager = plans_manager.get_label_manager(dataset_json)\n        gt_folder = join(nnUNet_raw, plans_manager.dataset_name, 'labelsTr')\n        if not isdir(gt_folder):\n            gt_folder = join(nnUNet_preprocessed, plans_manager.dataset_name, 'gt_segmentations')\n        compute_metrics_on_folder(gt_folder,\n                                  merged_output_folder,\n                                  join(merged_output_folder, 'summary.json'),\n                                  rw,\n                                  dataset_json['file_ending'],\n                                  label_manager.foreground_regions if label_manager.has_regions else\n                                  label_manager.foreground_labels,\n                                  label_manager.ignore_label,\n                                  num_processes)\n"
  },
  {
    "path": "nnunetv2/evaluation/evaluate_predictions.py",
    "content": "import multiprocessing\nimport os\nfrom copy import deepcopy\nfrom typing import Tuple, List, Union\n\nimport numpy as np\nfrom batchgenerators.utilities.file_and_folder_operations import subfiles, join, save_json, load_json, \\\n    isfile\nfrom nnunetv2.configuration import default_num_processes\nfrom nnunetv2.imageio.base_reader_writer import BaseReaderWriter\nfrom nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json, \\\n    determine_reader_writer_from_file_ending\nfrom nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO\n# the Evaluator class of the previous nnU-Net was great and all but man was it overengineered. Keep it simple\nfrom nnunetv2.utilities.json_export import recursive_fix_for_json_export\nfrom nnunetv2.utilities.plans_handling.plans_handler import PlansManager\n\n\ndef label_or_region_to_key(label_or_region: Union[int, Tuple[int]]):\n    return str(label_or_region)\n\n\ndef key_to_label_or_region(key: str):\n    try:\n        return int(key)\n    except ValueError:\n        key = key.replace('(', '')\n        key = key.replace(')', '')\n        split = key.split(',')\n        return tuple([int(i) for i in split if len(i) > 0])\n\n\ndef save_summary_json(results: dict, output_file: str):\n    \"\"\"\n    json does not support tuples as keys (why does it have to be so shitty) so we need to convert that shit\n    ourselves\n    \"\"\"\n    results_converted = deepcopy(results)\n    # convert keys in mean metrics\n    results_converted['mean'] = {label_or_region_to_key(k): results['mean'][k] for k in results['mean'].keys()}\n    # convert metric_per_case\n    for i in range(len(results_converted[\"metric_per_case\"])):\n        results_converted[\"metric_per_case\"][i]['metrics'] = \\\n            {label_or_region_to_key(k): results[\"metric_per_case\"][i]['metrics'][k]\n             for k in results[\"metric_per_case\"][i]['metrics'].keys()}\n    # sort_keys=True will make foreground_mean 
the first entry and thus easy to spot\n    save_json(results_converted, output_file, sort_keys=True)\n\n\ndef load_summary_json(filename: str):\n    results = load_json(filename)\n    # convert keys in mean metrics\n    results['mean'] = {key_to_label_or_region(k): results['mean'][k] for k in results['mean'].keys()}\n    # convert metric_per_case\n    for i in range(len(results[\"metric_per_case\"])):\n        results[\"metric_per_case\"][i]['metrics'] = \\\n            {key_to_label_or_region(k): results[\"metric_per_case\"][i]['metrics'][k]\n             for k in results[\"metric_per_case\"][i]['metrics'].keys()}\n    return results\n\n\ndef labels_to_list_of_regions(labels: List[int]):\n    return [(i,) for i in labels]\n\n\ndef region_or_label_to_mask(segmentation: np.ndarray, region_or_label: Union[int, Tuple[int, ...]]) -> np.ndarray:\n    if np.isscalar(region_or_label):\n        return segmentation == region_or_label\n    else:\n        mask = np.zeros_like(segmentation, dtype=bool)\n        for r in region_or_label:\n            mask[segmentation == r] = True\n    return mask\n\n\ndef compute_tp_fp_fn_tn(mask_ref: np.ndarray, mask_pred: np.ndarray, ignore_mask: np.ndarray = None):\n    if ignore_mask is None:\n        use_mask = np.ones_like(mask_ref, dtype=bool)\n    else:\n        use_mask = ~ignore_mask\n    tp = np.sum((mask_ref & mask_pred) & use_mask)\n    fp = np.sum(((~mask_ref) & mask_pred) & use_mask)\n    fn = np.sum((mask_ref & (~mask_pred)) & use_mask)\n    tn = np.sum(((~mask_ref) & (~mask_pred)) & use_mask)\n    return tp, fp, fn, tn\n\n\ndef compute_metrics(reference_file: str, prediction_file: str, image_reader_writer: BaseReaderWriter,\n                    labels_or_regions: Union[List[int], List[Union[int, Tuple[int, ...]]]],\n                    ignore_label: int = None) -> dict:\n    # load images\n    seg_ref, seg_ref_dict = image_reader_writer.read_seg(reference_file)\n    seg_pred, seg_pred_dict = 
image_reader_writer.read_seg(prediction_file)\n\n    ignore_mask = seg_ref == ignore_label if ignore_label is not None else None\n\n    results = {}\n    results['reference_file'] = reference_file\n    results['prediction_file'] = prediction_file\n    results['metrics'] = {}\n    for r in labels_or_regions:\n        results['metrics'][r] = {}\n        mask_ref = region_or_label_to_mask(seg_ref, r)\n        mask_pred = region_or_label_to_mask(seg_pred, r)\n        tp, fp, fn, tn = compute_tp_fp_fn_tn(mask_ref, mask_pred, ignore_mask)\n        if tp + fp + fn == 0:\n            results['metrics'][r]['Dice'] = np.nan\n            results['metrics'][r]['IoU'] = np.nan\n        else:\n            results['metrics'][r]['Dice'] = 2 * tp / (2 * tp + fp + fn)\n            results['metrics'][r]['IoU'] = tp / (tp + fp + fn)\n        results['metrics'][r]['FP'] = fp\n        results['metrics'][r]['TP'] = tp\n        results['metrics'][r]['FN'] = fn\n        results['metrics'][r]['TN'] = tn\n        results['metrics'][r]['n_pred'] = fp + tp\n        results['metrics'][r]['n_ref'] = fn + tp\n    return results\n\n\ndef compute_metrics_on_folder(folder_ref: str, folder_pred: str, output_file: str,\n                              image_reader_writer: BaseReaderWriter,\n                              file_ending: str,\n                              regions_or_labels: Union[List[int], List[Union[int, Tuple[int, ...]]]],\n                              ignore_label: int = None,\n                              num_processes: int = default_num_processes,\n                              chill: bool = True) -> dict:\n    \"\"\"\n    output_file must end with .json; can be None\n    \"\"\"\n    if output_file is not None:\n        assert output_file.endswith('.json'), 'output_file should end with .json'\n    files_pred = subfiles(folder_pred, suffix=file_ending, join=False)\n    files_ref = subfiles(folder_ref, suffix=file_ending, join=False)\n    if not chill:\n        present = 
[isfile(join(folder_pred, i)) for i in files_ref]\n        assert all(present), \"Not all files in folder_ref exist in folder_pred\"\n    files_ref = [join(folder_ref, i) for i in files_pred]\n    files_pred = [join(folder_pred, i) for i in files_pred]\n    with multiprocessing.get_context(\"spawn\").Pool(num_processes) as pool:\n        # for i in list(zip(files_ref, files_pred, [image_reader_writer] * len(files_pred), [regions_or_labels] * len(files_pred), [ignore_label] * len(files_pred))):\n        #     compute_metrics(*i)\n        results = pool.starmap(\n            compute_metrics,\n            list(zip(files_ref, files_pred, [image_reader_writer] * len(files_pred), [regions_or_labels] * len(files_pred),\n                     [ignore_label] * len(files_pred)))\n        )\n\n    # mean metric per class\n    metric_list = list(results[0]['metrics'][regions_or_labels[0]].keys())\n    means = {}\n    for r in regions_or_labels:\n        means[r] = {}\n        for m in metric_list:\n            means[r][m] = np.nanmean([i['metrics'][r][m] for i in results])\n\n    # foreground mean\n    foreground_mean = {}\n    for m in metric_list:\n        values = []\n        for k in means.keys():\n            if k == 0 or k == '0':\n                continue\n            values.append(means[k][m])\n        foreground_mean[m] = np.mean(values)\n\n    [recursive_fix_for_json_export(i) for i in results]\n    recursive_fix_for_json_export(means)\n    recursive_fix_for_json_export(foreground_mean)\n    result = {'metric_per_case': results, 'mean': means, 'foreground_mean': foreground_mean}\n    if output_file is not None:\n        save_summary_json(result, output_file)\n    return result\n    # print('DONE')\n\n\ndef compute_metrics_on_folder2(folder_ref: str, folder_pred: str, dataset_json_file: str, plans_file: str,\n                               output_file: str = None,\n                               num_processes: int = default_num_processes,\n                              
 chill: bool = False):\n    dataset_json = load_json(dataset_json_file)\n    # get file ending\n    file_ending = dataset_json['file_ending']\n\n    # get reader writer class\n    example_file = subfiles(folder_ref, suffix=file_ending, join=True)[0]\n    rw = determine_reader_writer_from_dataset_json(dataset_json, example_file)()\n\n    # maybe auto set output file\n    if output_file is None:\n        output_file = join(folder_pred, 'summary.json')\n\n    lm = PlansManager(plans_file).get_label_manager(dataset_json)\n    compute_metrics_on_folder(folder_ref, folder_pred, output_file, rw, file_ending,\n                              lm.foreground_regions if lm.has_regions else lm.foreground_labels, lm.ignore_label,\n                              num_processes, chill=chill)\n\n\ndef compute_metrics_on_folder_simple(folder_ref: str, folder_pred: str, labels: Union[Tuple[int, ...], List[int]],\n                                     output_file: str = None,\n                                     num_processes: int = default_num_processes,\n                                     ignore_label: int = None,\n                                     chill: bool = False):\n    example_file = subfiles(folder_ref, join=True)[0]\n    file_ending = os.path.splitext(example_file)[-1]\n    rw = determine_reader_writer_from_file_ending(file_ending, example_file, allow_nonmatching_filename=True,\n                                                  verbose=False)()\n    # maybe auto set output file\n    if output_file is None:\n        output_file = join(folder_pred, 'summary.json')\n    compute_metrics_on_folder(folder_ref, folder_pred, output_file, rw, file_ending,\n                              labels, ignore_label=ignore_label, num_processes=num_processes, chill=chill)\n\n\ndef evaluate_folder_entry_point():\n    import argparse\n    parser = argparse.ArgumentParser()\n    parser.add_argument('gt_folder', type=str, help='folder with gt segmentations')\n    parser.add_argument('pred_folder', 
type=str, help='folder with predicted segmentations')\n    parser.add_argument('-djfile', type=str, required=True,\n                        help='dataset.json file')\n    parser.add_argument('-pfile', type=str, required=True,\n                        help='plans.json file')\n    parser.add_argument('-o', type=str, required=False, default=None,\n                        help='Output file. Optional. Default: pred_folder/summary.json')\n    parser.add_argument('-np', type=int, required=False, default=default_num_processes,\n                        help=f'number of processes used. Optional. Default: {default_num_processes}')\n    parser.add_argument('--chill', action='store_true', help='do not crash if folder_pred does not have all files that are present in folder_gt')\n    args = parser.parse_args()\n    compute_metrics_on_folder2(args.gt_folder, args.pred_folder, args.djfile, args.pfile, args.o, args.np, chill=args.chill)\n\n\ndef evaluate_simple_entry_point():\n    import argparse\n    parser = argparse.ArgumentParser()\n    parser.add_argument('gt_folder', type=str, help='folder with gt segmentations')\n    parser.add_argument('pred_folder', type=str, help='folder with predicted segmentations')\n    parser.add_argument('-l', type=int, nargs='+', required=True,\n                        help='list of labels')\n    parser.add_argument('-il', type=int, required=False, default=None,\n                        help='ignore label')\n    parser.add_argument('-o', type=str, required=False, default=None,\n                        help='Output file. Optional. Default: pred_folder/summary.json')\n    parser.add_argument('-np', type=int, required=False, default=default_num_processes,\n                        help=f'number of processes used. Optional. 
Default: {default_num_processes}')\n    parser.add_argument('--chill', action='store_true', help='do not crash if folder_pred does not have all files that are present in folder_gt')\n\n    args = parser.parse_args()\n    compute_metrics_on_folder_simple(args.gt_folder, args.pred_folder, args.l, args.o, args.np, args.il, chill=args.chill)\n\n\nif __name__ == '__main__':\n    folder_ref = '/media/fabian/data/nnUNet_raw/Dataset004_Hippocampus/labelsTr'\n    folder_pred = '/home/fabian/results/nnUNet_remake/Dataset004_Hippocampus/nnUNetModule__nnUNetPlans__3d_fullres/fold_0/validation'\n    output_file = '/home/fabian/results/nnUNet_remake/Dataset004_Hippocampus/nnUNetModule__nnUNetPlans__3d_fullres/fold_0/validation/summary.json'\n    image_reader_writer = SimpleITKIO()\n    file_ending = '.nii.gz'\n    regions = labels_to_list_of_regions([1, 2])\n    ignore_label = None\n    num_processes = 12\n    compute_metrics_on_folder(folder_ref, folder_pred, output_file, image_reader_writer, file_ending, regions, ignore_label,\n                              num_processes)\n"
  },
  {
    "path": "nnunetv2/evaluation/find_best_configuration.py",
    "content": "import argparse\nimport os.path\nfrom copy import deepcopy\nfrom typing import Union, List, Tuple\n\nfrom batchgenerators.utilities.file_and_folder_operations import (\n    load_json, join, isdir, listdir, save_json\n)\nfrom nnunetv2.configuration import default_num_processes\nfrom nnunetv2.ensembling.ensemble import ensemble_crossvalidations\nfrom nnunetv2.evaluation.accumulate_cv_results import accumulate_cv_results\nfrom nnunetv2.evaluation.evaluate_predictions import compute_metrics_on_folder, load_summary_json\nfrom nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw, nnUNet_results\nfrom nnunetv2.postprocessing.remove_connected_components import determine_postprocessing\nfrom nnunetv2.utilities.file_path_utilities import maybe_convert_to_dataset_name, get_output_folder, \\\n    convert_identifier_to_trainer_plans_config, get_ensemble_name, folds_tuple_to_string\nfrom nnunetv2.utilities.plans_handling.plans_handler import PlansManager\n\ndefault_trained_models = tuple([\n    {'plans': 'nnUNetPlans', 'configuration': '2d', 'trainer': 'nnUNetTrainer'},\n    {'plans': 'nnUNetPlans', 'configuration': '3d_fullres', 'trainer': 'nnUNetTrainer'},\n    {'plans': 'nnUNetPlans', 'configuration': '3d_lowres', 'trainer': 'nnUNetTrainer'},\n    {'plans': 'nnUNetPlans', 'configuration': '3d_cascade_fullres', 'trainer': 'nnUNetTrainer'},\n])\n\n\ndef filter_available_models(model_dict: Union[List[dict], Tuple[dict, ...]], dataset_name_or_id: Union[str, int]):\n    valid = []\n    for trained_model in model_dict:\n        plans_manager = PlansManager(join(nnUNet_preprocessed, maybe_convert_to_dataset_name(dataset_name_or_id),\n                               trained_model['plans'] + '.json'))\n        # check if configuration exists\n        # 3d_cascade_fullres and 3d_lowres do not exist for each dataset so we allow them to be absent IF they are not\n        # specified in the plans file\n        if trained_model['configuration'] not in 
plans_manager.available_configurations:\n            print(f\"Configuration {trained_model['configuration']} not found in plans {trained_model['plans']}.\\n\"\n                  f\"Inferred plans file: {join(nnUNet_preprocessed, maybe_convert_to_dataset_name(dataset_name_or_id), trained_model['plans'] + '.json')}.\")\n            continue\n\n        # check if trained model output folder exists. This is a requirement. No mercy here.\n        expected_output_folder = get_output_folder(dataset_name_or_id, trained_model['trainer'], trained_model['plans'],\n                                                   trained_model['configuration'], fold=None)\n        if not isdir(expected_output_folder):\n            raise RuntimeError(f\"Trained model {trained_model} does not have an output folder. \"\n                  f\"Expected: {expected_output_folder}. Please run the training for this model! (don't forget \"\n                  f\"the --npz flag if you want to ensemble multiple configurations)\")\n\n        valid.append(trained_model)\n    return valid\n\n\ndef generate_inference_command(dataset_name_or_id: Union[int, str], configuration_name: str,\n                               plans_identifier: str = 'nnUNetPlans', trainer_name: str = 'nnUNetTrainer',\n                               folds: Union[List[int], Tuple[int, ...]] = (0, 1, 2, 3, 4),\n                               folder_with_segs_from_prev_stage: str = None,\n                               input_folder: str = 'INPUT_FOLDER',\n                               output_folder: str = 'OUTPUT_FOLDER',\n                               save_npz: bool = False):\n    fold_str = ''\n    for f in folds:\n        fold_str += f' {f}'\n\n    predict_command = ''\n    trained_model_folder = get_output_folder(dataset_name_or_id, trainer_name, plans_identifier, configuration_name, fold=None)\n    plans_manager = PlansManager(join(trained_model_folder, 'plans.json'))\n    configuration_manager = 
plans_manager.get_configuration(configuration_name)\n    if 'previous_stage' in plans_manager.available_configurations:\n        prev_stage = configuration_manager.previous_stage_name\n        predict_command += generate_inference_command(dataset_name_or_id, prev_stage, plans_identifier, trainer_name,\n                                                      folds, None, output_folder='OUTPUT_FOLDER_PREV_STAGE') + '\\n'\n        folder_with_segs_from_prev_stage = 'OUTPUT_FOLDER_PREV_STAGE'\n\n    # append rather than assign, otherwise a previous-stage command built above would be discarded\n    predict_command += f'nnUNetv2_predict -d {dataset_name_or_id} -i {input_folder} -o {output_folder} -f {fold_str} ' \\\n                      f'-tr {trainer_name} -c {configuration_name} -p {plans_identifier}'\n    if folder_with_segs_from_prev_stage is not None:\n        predict_command += f' -prev_stage_predictions {folder_with_segs_from_prev_stage}'\n    if save_npz:\n        predict_command += ' --save_probabilities'\n    return predict_command\n\n\ndef find_best_configuration(dataset_name_or_id,\n                            allowed_trained_models: Union[List[dict], Tuple[dict, ...]] = default_trained_models,\n                            allow_ensembling: bool = True,\n                            num_processes: int = default_num_processes,\n                            overwrite: bool = True,\n                            folds: Union[List[int], Tuple[int, ...]] = (0, 1, 2, 3, 4),\n                            strict: bool = False):\n    dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)\n    all_results = {}\n\n    allowed_trained_models = filter_available_models(deepcopy(allowed_trained_models), dataset_name_or_id)\n\n    for m in allowed_trained_models:\n        output_folder = get_output_folder(dataset_name_or_id, m['trainer'], m['plans'], m['configuration'], fold=None)\n        if not isdir(output_folder) and strict:\n            raise RuntimeError(f'{dataset_name}: The output folder of plans {m[\"plans\"]} configuration '\n                               
f'{m[\"configuration\"]} is missing. Please train the model (all requested folds!) first!')\n        identifier = os.path.basename(output_folder)\n        merged_output_folder = join(output_folder, f'crossval_results_folds_{folds_tuple_to_string(folds)}')\n        accumulate_cv_results(output_folder, merged_output_folder, folds, num_processes, overwrite)\n        all_results[identifier] = {\n            'source': merged_output_folder,\n            'result': load_summary_json(join(merged_output_folder, 'summary.json'))['foreground_mean']['Dice']\n        }\n\n    if allow_ensembling:\n        for i in range(len(allowed_trained_models)):\n            for j in range(i + 1, len(allowed_trained_models)):\n                m1, m2 = allowed_trained_models[i], allowed_trained_models[j]\n\n                output_folder_1 = get_output_folder(dataset_name_or_id, m1['trainer'], m1['plans'], m1['configuration'], fold=None)\n                output_folder_2 = get_output_folder(dataset_name_or_id, m2['trainer'], m2['plans'], m2['configuration'], fold=None)\n                identifier = get_ensemble_name(output_folder_1, output_folder_2, folds)\n\n                output_folder_ensemble = join(nnUNet_results, dataset_name, 'ensembles', identifier)\n\n                ensemble_crossvalidations([output_folder_1, output_folder_2], output_folder_ensemble, folds,\n                                          num_processes, overwrite=overwrite)\n\n                # evaluate ensembled predictions\n                plans_manager = PlansManager(join(output_folder_1, 'plans.json'))\n                dataset_json = load_json(join(output_folder_1, 'dataset.json'))\n                label_manager = plans_manager.get_label_manager(dataset_json)\n                rw = plans_manager.image_reader_writer_class()\n\n                compute_metrics_on_folder(join(nnUNet_preprocessed, dataset_name, 'gt_segmentations'),\n                                          output_folder_ensemble,\n                           
               join(output_folder_ensemble, 'summary.json'),\n                                          rw,\n                                          dataset_json['file_ending'],\n                                          label_manager.foreground_regions if label_manager.has_regions else\n                                          label_manager.foreground_labels,\n                                          label_manager.ignore_label,\n                                          num_processes)\n                all_results[identifier] = \\\n                    {\n                    'source': output_folder_ensemble,\n                    'result': load_summary_json(join(output_folder_ensemble, 'summary.json'))['foreground_mean']['Dice']\n                    }\n\n    # pick best and report inference command\n    best_score = max([i['result'] for i in all_results.values()])\n    best_keys = [k for k in all_results.keys() if all_results[k]['result'] == best_score]  # may never happen but theoretically\n    # there can be a tie. 
Let's pick the first model in this case because it's going to be the simpler one (ensembles\n    # come after single configs)\n    best_key = best_keys[0]\n\n    print()\n    print('***All results:***')\n    for k, v in all_results.items():\n        print(f'{k}: {v[\"result\"]}')\n    print(f'\\n*Best*: {best_key}: {all_results[best_key][\"result\"]}')\n    print()\n\n    print('***Determining postprocessing for best model/ensemble***')\n    determine_postprocessing(all_results[best_key]['source'], join(nnUNet_preprocessed, dataset_name, 'gt_segmentations'),\n                             plans_file_or_dict=join(all_results[best_key]['source'], 'plans.json'),\n                             dataset_json_file_or_dict=join(all_results[best_key]['source'], 'dataset.json'),\n                             num_processes=num_processes, keep_postprocessed_files=True)\n\n    # in addition to just reading the console output (how it was previously) we should return the information\n    # needed to run the full inference via API\n    return_dict = {\n        'folds': folds,\n        'dataset_name_or_id': dataset_name_or_id,\n        'considered_models': allowed_trained_models,\n        'ensembling_allowed': allow_ensembling,\n        'all_results': {i: j['result'] for i, j in all_results.items()},\n        'best_model_or_ensemble': {\n            'result_on_crossval_pre_pp': all_results[best_key][\"result\"],\n            'result_on_crossval_post_pp': load_json(join(all_results[best_key]['source'], 'postprocessed', 'summary.json'))['foreground_mean']['Dice'],\n            'postprocessing_file': join(all_results[best_key]['source'], 'postprocessing.pkl'),\n            'some_plans_file': join(all_results[best_key]['source'], 'plans.json'),\n            # just needed for label handling, can\n            # come from any of the ensemble members (if any)\n            'selected_model_or_models': []\n        }\n    }\n    # convert best key to inference command:\n    if 
best_key.startswith('ensemble___'):\n        prefix, m1, m2, folds_string = best_key.split('___')\n        tr1, pl1, c1 = convert_identifier_to_trainer_plans_config(m1)\n        tr2, pl2, c2 = convert_identifier_to_trainer_plans_config(m2)\n        return_dict['best_model_or_ensemble']['selected_model_or_models'].append(\n            {\n                'configuration': c1,\n                'trainer': tr1,\n                'plans_identifier': pl1,\n            })\n        return_dict['best_model_or_ensemble']['selected_model_or_models'].append(\n            {\n                'configuration': c2,\n                'trainer': tr2,\n                'plans_identifier': pl2,\n            })\n    else:\n        tr, pl, c = convert_identifier_to_trainer_plans_config(best_key)\n        return_dict['best_model_or_ensemble']['selected_model_or_models'].append(\n            {\n                'configuration': c,\n                'trainer': tr,\n                'plans_identifier': pl,\n            })\n\n    save_json(return_dict, join(nnUNet_results, dataset_name, 'inference_information.json'))  # save this so that we don't have to run this\n    # every time someone wants to be reminded of the inference commands. 
They can just load this and give it to\n    # print_inference_instructions\n\n    # print it\n    print_inference_instructions(return_dict, instructions_file=join(nnUNet_results, dataset_name, 'inference_instructions.txt'))\n    return return_dict\n\n\ndef print_inference_instructions(inference_info_dict: dict, instructions_file: str = None):\n    def _print_and_maybe_write_to_file(string):\n        print(string)\n        if f_handle is not None:\n            f_handle.write(f'{string}\\n')\n\n    f_handle = open(instructions_file, 'w') if instructions_file is not None else None\n    print()\n    _print_and_maybe_write_to_file('***Run inference like this:***\\n')\n    output_folders = []\n\n    dataset_name_or_id = inference_info_dict['dataset_name_or_id']\n    if len(inference_info_dict['best_model_or_ensemble']['selected_model_or_models']) > 1:\n        is_ensemble = True\n        _print_and_maybe_write_to_file('An ensemble won! What a surprise! Run the following commands to run predictions with the ensemble members:\\n')\n    else:\n        is_ensemble = False\n\n    for j, i in enumerate(inference_info_dict['best_model_or_ensemble']['selected_model_or_models']):\n        tr, c, pl = i['trainer'], i['configuration'], i['plans_identifier']\n        if is_ensemble:\n            output_folder_name = f\"OUTPUT_FOLDER_MODEL_{j+1}\"\n        else:\n            output_folder_name = \"OUTPUT_FOLDER\"\n        output_folders.append(output_folder_name)\n\n        _print_and_maybe_write_to_file(generate_inference_command(dataset_name_or_id, c, pl, tr, inference_info_dict['folds'],\n                                         save_npz=is_ensemble, output_folder=output_folder_name))\n\n    if is_ensemble:\n        output_folder_str = output_folders[0]\n        for o in output_folders[1:]:\n            output_folder_str += f' {o}'\n        output_ensemble = \"OUTPUT_FOLDER\"\n        _print_and_maybe_write_to_file('\\nThen run ensembling with:\\n')\n        
_print_and_maybe_write_to_file(f\"nnUNetv2_ensemble -i {output_folder_str} -o {output_ensemble} -np {default_num_processes}\")\n\n    _print_and_maybe_write_to_file(\"\\n***Once inference is completed, run postprocessing like this:***\\n\")\n    _print_and_maybe_write_to_file(f\"nnUNetv2_apply_postprocessing -i OUTPUT_FOLDER -o OUTPUT_FOLDER_PP \"\n          f\"-pp_pkl_file {inference_info_dict['best_model_or_ensemble']['postprocessing_file']} -np {default_num_processes} \"\n          f\"-plans_json {inference_info_dict['best_model_or_ensemble']['some_plans_file']}\")\n\n    # close the instructions file (if one was opened) so that its content is flushed to disk\n    if f_handle is not None:\n        f_handle.close()\n\n\ndef dumb_trainer_config_plans_to_trained_models_dict(trainers: List[str], configs: List[str], plans: List[str]):\n    \"\"\"\n    function is called dumb because it's dumb\n    \"\"\"\n    ret = []\n    for t in trainers:\n        for c in configs:\n            for p in plans:\n                ret.append(\n                    {'plans': p, 'configuration': c, 'trainer': t}\n                )\n    return tuple(ret)\n\n\ndef find_best_configuration_entry_point():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('dataset_name_or_id', type=str, help='Dataset Name or id')\n    parser.add_argument('-p', nargs='+', required=False, default=['nnUNetPlans'],\n                        help='List of plan identifiers. Default: nnUNetPlans')\n    parser.add_argument('-c', nargs='+', required=False, default=['2d', '3d_fullres', '3d_lowres', '3d_cascade_fullres'],\n                        help=\"List of configurations. Default: ['2d', '3d_fullres', '3d_lowres', '3d_cascade_fullres']\")\n    parser.add_argument('-tr', nargs='+', required=False, default=['nnUNetTrainer'],\n                        help='List of trainers. 
Default: nnUNetTrainer')\n    parser.add_argument('-np', required=False, default=default_num_processes, type=int,\n                        help='Number of processes to use for ensembling, postprocessing etc')\n    parser.add_argument('-f', nargs='+', type=int, default=(0, 1, 2, 3, 4),\n                        help='Folds to use. Default: 0 1 2 3 4')\n    parser.add_argument('--disable_ensembling', action='store_true', required=False,\n                        help='Set this flag to disable ensembling')\n    parser.add_argument('--no_overwrite', action='store_true',\n                        help='If set we will not overwrite already ensembled files etc. May speed up consecutive '\n                             'runs of this command (why would you want to do that?) at the risk of not updating '\n                             'outdated results.')\n    args = parser.parse_args()\n\n    model_dict = dumb_trainer_config_plans_to_trained_models_dict(args.tr, args.c, args.p)\n    dataset_name = maybe_convert_to_dataset_name(args.dataset_name_or_id)\n\n    find_best_configuration(dataset_name, model_dict, allow_ensembling=not args.disable_ensembling,\n                            num_processes=args.np, overwrite=not args.no_overwrite, folds=args.f,\n                            strict=False)\n\n\ndef accumulate_crossval_results_entry_point():\n    parser = argparse.ArgumentParser('Copies all predicted segmentations from the individual folds into one joint '\n                                     'folder and evaluates them')\n    parser.add_argument('dataset_name_or_id', type=str, help='Dataset Name or id')\n    parser.add_argument('-c', type=str, required=True,\n                        default='3d_fullres',\n                        help=\"Configuration\")\n    parser.add_argument('-o', type=str, required=False, default=None,\n                        help=\"Output folder. 
If not specified, the output folder will be located in the trained \" \\\n                             \"model directory (named crossval_results_folds_XXX).\")\n    parser.add_argument('-f', nargs='+', type=int, default=(0, 1, 2, 3, 4),\n                        help='Folds to use. Default: 0 1 2 3 4')\n    parser.add_argument('-p', type=str, required=False, default='nnUNetPlans',\n                        help='Plan identifier in which to search for the specified configuration. Default: nnUNetPlans')\n    parser.add_argument('-tr', type=str, required=False, default='nnUNetTrainer',\n                        help='Trainer class. Default: nnUNetTrainer')\n    args = parser.parse_args()\n    trained_model_folder = get_output_folder(args.dataset_name_or_id, args.tr, args.p, args.c)\n\n    if args.o is None:\n        merged_output_folder = join(trained_model_folder, f'crossval_results_folds_{folds_tuple_to_string(args.f)}')\n    else:\n        merged_output_folder = args.o\n        if isdir(merged_output_folder) and len(listdir(merged_output_folder)) > 0:\n            raise FileExistsError(\n                f\"Output folder {merged_output_folder} exists and is not empty. \"\n                f\"To avoid data loss, nnUNet requires an empty output folder.\"\n            )\n\n    accumulate_cv_results(trained_model_folder, merged_output_folder, args.f)\n\n\nif __name__ == '__main__':\n    find_best_configuration(4,\n                            default_trained_models,\n                            True,\n                            8,\n                            False,\n                            (0, 1, 2, 3, 4))\n"
  },
  {
    "path": "nnunetv2/experiment_planning/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/experiment_planning/dataset_fingerprint/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/experiment_planning/dataset_fingerprint/fingerprint_extractor.py",
    "content": "import multiprocessing\nimport os\nfrom time import sleep\nfrom typing import List, Type, Union\n\nimport numpy as np\nfrom batchgenerators.utilities.file_and_folder_operations import load_json, join, save_json, isfile, maybe_mkdir_p\nfrom tqdm import tqdm\n\nfrom nnunetv2.imageio.base_reader_writer import BaseReaderWriter\nfrom nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json\nfrom nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed\nfrom nnunetv2.preprocessing.cropping.cropping import crop_to_nonzero\nfrom nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name\nfrom nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets\n\n\nclass DatasetFingerprintExtractor(object):\n    def __init__(self, dataset_name_or_id: Union[str, int], num_processes: int = 8, verbose: bool = False):\n        \"\"\"\n        extracts the dataset fingerprint used for experiment planning. The dataset fingerprint will be saved as a\n        json file in the input_folder\n\n        Philosophy here is to do only what we really need. Don't store stuff that we can easily read from somewhere\n        else. Don't compute stuff we don't need (except for intensity_statistics_per_channel)\n        \"\"\"\n        dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)\n        self.verbose = verbose\n        self.show_progress_bar = True\n\n        self.dataset_name = dataset_name\n        self.input_folder = join(nnUNet_raw, dataset_name)\n        self.num_processes = num_processes\n        self.dataset_json = load_json(join(self.input_folder, 'dataset.json'))\n        self.dataset = get_filenames_of_train_images_and_targets(self.input_folder, self.dataset_json)\n\n        # We don't want to use all foreground voxels because that can accumulate a lot of data (out of memory). It is\n        # also not critically important to get all pixels as long as there are enough. 
Let's use 10e7 voxels in total\n        # (for the entire dataset)\n        self.num_foreground_voxels_for_intensitystats = 10e7\n\n    @staticmethod\n    def collect_foreground_intensities(segmentation: np.ndarray, images: np.ndarray, seed: int = 1234,\n                                       num_samples: int = 10000):\n        \"\"\"\n        images=image with multiple channels = shape (c, x, y(, z))\n        \"\"\"\n        assert images.ndim == 4 and segmentation.ndim == 4\n        assert not np.any(np.isnan(segmentation)), \"Segmentation contains NaN values. grrrr.... :-(\"\n        assert not np.any(np.isnan(images)), \"Images contains NaN values. grrrr.... :-(\"\n\n        rs = np.random.RandomState(seed)\n\n        intensities_per_channel = []\n        # we don't use the intensity_statistics_per_channel at all, it's just something that might be nice to have\n        intensity_statistics_per_channel = []\n\n        # segmentation is 4d: 1,x,y,z. We need to remove the empty dimension for the following code to work\n        foreground_mask = segmentation[0] > 0\n        percentiles = np.array((0.5, 50.0, 99.5))\n\n        for i in range(len(images)):\n            foreground_pixels = images[i][foreground_mask]\n            num_fg = len(foreground_pixels)\n            # sample with replacement so that we don't get issues with cases that have less than num_samples\n            # foreground_pixels. 
We could also just sample fewer in those cases but that would then cause these\n            # training cases to be underrepresented\n            intensities_per_channel.append(\n                rs.choice(foreground_pixels, num_samples, replace=True) if num_fg > 0 else [])\n\n            mean, median, mini, maxi, percentile_99_5, percentile_00_5 = np.nan, np.nan, np.nan, np.nan, np.nan, np.nan\n            if num_fg > 0:\n                percentile_00_5, median, percentile_99_5 = np.percentile(foreground_pixels, percentiles)\n                mean = np.mean(foreground_pixels)\n                mini = np.min(foreground_pixels)\n                maxi = np.max(foreground_pixels)\n\n            intensity_statistics_per_channel.append({\n                'mean': mean,\n                'median': median,\n                'min': mini,\n                'max': maxi,\n                'percentile_99_5': percentile_99_5,\n                'percentile_00_5': percentile_00_5,\n\n            })\n\n        return intensities_per_channel, intensity_statistics_per_channel\n\n    @staticmethod\n    def analyze_case(image_files: List[str], segmentation_file: str, reader_writer_class: Type[BaseReaderWriter],\n                     num_samples: int = 10000):\n        rw = reader_writer_class()\n        images, properties_images = rw.read_images(image_files)\n        segmentation, properties_seg = rw.read_seg(segmentation_file)\n\n        # we no longer crop and save the cropped images before this is run. Instead we run the cropping on the fly.\n        # Downside is that we need to do this twice (once here and once during preprocessing). Upside is that we don't\n        # need to save the cropped data anymore. Given that cropping is not too expensive it makes sense to do it this\n        # way. 
This is only possible because we are now using our new input/output interface.\n        data_cropped, seg_cropped, bbox = crop_to_nonzero(images, segmentation)\n\n        foreground_intensities_per_channel, foreground_intensity_stats_per_channel = \\\n            DatasetFingerprintExtractor.collect_foreground_intensities(seg_cropped, data_cropped,\n                                                                       num_samples=num_samples)\n\n        spacing = properties_images['spacing']\n\n        shape_before_crop = images.shape[1:]\n        shape_after_crop = data_cropped.shape[1:]\n        relative_size_after_cropping = np.prod(shape_after_crop) / np.prod(shape_before_crop)\n        return shape_after_crop, spacing, foreground_intensities_per_channel, foreground_intensity_stats_per_channel, \\\n               relative_size_after_cropping\n\n    def run(self, overwrite_existing: bool = False) -> dict:\n        # we do not save the properties file in self.input_folder because that folder might be read-only. We can only\n        # reliably write in nnUNet_preprocessed and nnUNet_results, so nnUNet_preprocessed it is\n        preprocessed_output_folder = join(nnUNet_preprocessed, self.dataset_name)\n        maybe_mkdir_p(preprocessed_output_folder)\n        properties_file = join(preprocessed_output_folder, 'dataset_fingerprint.json')\n\n        if not isfile(properties_file) or overwrite_existing:\n            reader_writer_class = determine_reader_writer_from_dataset_json(self.dataset_json,\n                                                                            # yikes. 
Rip the following line\n                                                                            self.dataset[self.dataset.keys().__iter__().__next__()]['images'][0])\n\n            # determine how many foreground voxels we need to sample per training case\n            num_foreground_samples_per_case = int(self.num_foreground_voxels_for_intensitystats //\n                                                  len(self.dataset))\n\n            r = []\n            with multiprocessing.get_context(\"spawn\").Pool(self.num_processes) as p:\n                for k in self.dataset.keys():\n                    r.append(p.starmap_async(DatasetFingerprintExtractor.analyze_case,\n                                             ((self.dataset[k]['images'], self.dataset[k]['label'], reader_writer_class,\n                                               num_foreground_samples_per_case),)))\n                remaining = list(range(len(self.dataset)))\n                # p is pretty nifti. If we kill workers they just respawn but don't do any work.\n                # So we need to store the original pool of workers.\n                workers = [j for j in p._pool]\n                with tqdm(desc=\"Extracting dataset fingerprint\", total=len(self.dataset),\n                          disable=not getattr(self, 'show_progress_bar', True)) as pbar:\n                    while len(remaining) > 0:\n                        all_alive = all([j.is_alive() for j in workers])\n                        if not all_alive:\n                            raise RuntimeError('Some background worker is 6 feet under. Yuck. \\n'\n                                               'OK jokes aside.\\n'\n                                               'One of your background processes is missing. This could be because of '\n                                               'an error (look for an error message) or because it was killed '\n                                               'by your OS due to running out of RAM. 
If you don\\'t see '\n                                               'an error message, out of RAM is likely the problem. In that case '\n                                               'reducing the number of workers might help')\n                        done = [i for i in remaining if r[i].ready()]\n                        for _ in done:\n                            pbar.update()\n                        remaining = [i for i in remaining if i not in done]\n                        sleep(0.1)\n\n            # results = ptqdm(DatasetFingerprintExtractor.analyze_case,\n            #                 (training_images_per_case, training_labels_per_case),\n            #                 processes=self.num_processes, zipped=True, reader_writer_class=reader_writer_class,\n            #                 num_samples=num_foreground_samples_per_case, disable=self.verbose)\n            results = [i.get()[0] for i in r]\n\n            shapes_after_crop = [r[0] for r in results]\n            spacings = [r[1] for r in results]\n            foreground_intensities_per_channel = [np.concatenate([r[2][i] for r in results]) for i in\n                                                  range(len(results[0][2]))]\n            foreground_intensities_per_channel = np.array(foreground_intensities_per_channel)\n            # we drop this so that the json file is somewhat human readable\n            # foreground_intensity_stats_by_case_and_modality = [r[3] for r in results]\n            median_relative_size_after_cropping = np.median([r[4] for r in results], 0)\n            num_channels = len(self.dataset_json['channel_names'].keys()\n                                 if 'channel_names' in self.dataset_json.keys()\n                                 else self.dataset_json['modality'].keys())\n            intensity_statistics_per_channel = {}\n            percentiles = np.array((0.5, 50.0, 99.5))\n            for i in range(num_channels):\n                percentile_00_5, median, percentile_99_5 = 
np.percentile(foreground_intensities_per_channel[i],\n                                                                         percentiles)\n                intensity_statistics_per_channel[i] = {\n                    'mean': float(np.mean(foreground_intensities_per_channel[i])),\n                    'median': float(median),\n                    'std': float(np.std(foreground_intensities_per_channel[i])),\n                    'min': float(np.min(foreground_intensities_per_channel[i])),\n                    'max': float(np.max(foreground_intensities_per_channel[i])),\n                    'percentile_99_5': float(percentile_99_5),\n                    'percentile_00_5': float(percentile_00_5),\n                }\n\n            fingerprint = {\n                    \"spacings\": spacings,\n                    \"shapes_after_crop\": shapes_after_crop,\n                    'foreground_intensity_properties_per_channel': intensity_statistics_per_channel,\n                    \"median_relative_size_after_cropping\": median_relative_size_after_cropping\n                }\n\n            try:\n                save_json(fingerprint, properties_file)\n            except Exception as e:\n                if isfile(properties_file):\n                    os.remove(properties_file)\n                raise e\n        else:\n            fingerprint = load_json(properties_file)\n        return fingerprint\n\n\nif __name__ == '__main__':\n    dfe = DatasetFingerprintExtractor(2, 8)\n    dfe.run(overwrite_existing=False)\n"
  },
  {
    "path": "nnunetv2/experiment_planning/experiment_planners/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py",
    "content": "import shutil\nfrom copy import deepcopy\nfrom typing import List, Union, Tuple\n\nimport numpy as np\nimport torch\nfrom batchgenerators.utilities.file_and_folder_operations import load_json, join, save_json, isfile, maybe_mkdir_p\nfrom dynamic_network_architectures.architectures.unet import PlainConvUNet\nfrom dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm\n\nfrom nnunetv2.configuration import ANISO_THRESHOLD\nfrom nnunetv2.experiment_planning.experiment_planners.network_topology import get_pool_and_conv_props\nfrom nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json\nfrom nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed\nfrom nnunetv2.preprocessing.normalization.map_channel_name_to_normalization import get_normalization_scheme\nfrom nnunetv2.preprocessing.resampling.default_resampling import resample_data_or_seg_to_shape, compute_new_shape\nfrom nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name\nfrom nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA\nfrom nnunetv2.utilities.get_network_from_plans import get_network_from_plans\nfrom nnunetv2.utilities.json_export import recursive_fix_for_json_export\nfrom nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets\n\n\nclass ExperimentPlanner(object):\n    def __init__(self, dataset_name_or_id: Union[str, int],\n                 gpu_memory_target_in_gb: float = 8,\n                 preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetPlans',\n                 overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,\n                 suppress_transpose: bool = False):\n        \"\"\"\n        overwrite_target_spacing only affects 3d_fullres! 
(but by extension 3d_lowres which starts with fullres may\n        also be affected)\n        \"\"\"\n\n        self.dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)\n        self.suppress_transpose = suppress_transpose\n        self.raw_dataset_folder = join(nnUNet_raw, self.dataset_name)\n        preprocessed_folder = join(nnUNet_preprocessed, self.dataset_name)\n        self.dataset_json = load_json(join(self.raw_dataset_folder, 'dataset.json'))\n        self.dataset = get_filenames_of_train_images_and_targets(self.raw_dataset_folder, self.dataset_json)\n\n        # load dataset fingerprint\n        if not isfile(join(preprocessed_folder, 'dataset_fingerprint.json')):\n            raise RuntimeError('Fingerprint missing for this dataset. Please run nnUNetv2_extract_fingerprint')\n\n        self.dataset_fingerprint = load_json(join(preprocessed_folder, 'dataset_fingerprint.json'))\n\n        self.anisotropy_threshold = ANISO_THRESHOLD\n\n        self.UNet_base_num_features = 32\n        self.UNet_class = PlainConvUNet\n        # the following two numbers are really arbitrary and were set to reproduce nnU-Net v1's configurations as\n        # much as possible\n        self.UNet_reference_val_3d = 560000000  # 455600128  550000000\n        self.UNet_reference_val_2d = 85000000  # 83252480\n        self.UNet_reference_com_nfeatures = 32\n        self.UNet_reference_val_corresp_GB = 8\n        self.UNet_reference_val_corresp_bs_2d = 12\n        self.UNet_reference_val_corresp_bs_3d = 2\n        self.UNet_featuremap_min_edge_length = 4\n        self.UNet_blocks_per_stage_encoder = (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)\n        self.UNet_blocks_per_stage_decoder = (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)\n        self.UNet_min_batch_size = 2\n        self.UNet_max_features_2d = 512\n        self.UNet_max_features_3d = 320\n        self.max_dataset_covered = 0.05 # we limit the batch size so that no more than 5% of the dataset can be seen\n        # 
in a single forward/backward pass\n\n        self.UNet_vram_target_GB = gpu_memory_target_in_gb\n\n        self.lowres_creation_threshold = 0.25  # if the patch size of fullres is less than 25% of the voxels in the\n        # median shape then we need a lowres config as well\n\n        self.preprocessor_name = preprocessor_name\n        self.plans_identifier = plans_name\n        self.overwrite_target_spacing = overwrite_target_spacing\n        assert overwrite_target_spacing is None or len(overwrite_target_spacing), 'if overwrite_target_spacing is ' \\\n                                                                                  'used then three floats must be ' \\\n                                                                                  'given (as list or tuple)'\n        assert overwrite_target_spacing is None or all([isinstance(i, float) for i in overwrite_target_spacing]), \\\n            'if overwrite_target_spacing is used then three floats must be given (as list or tuple)'\n\n        self.plans = None\n\n        if isfile(join(self.raw_dataset_folder, 'splits_final.json')):\n            _maybe_copy_splits_file(join(self.raw_dataset_folder, 'splits_final.json'),\n                                    join(preprocessed_folder, 'splits_final.json'))\n\n    def determine_reader_writer(self):\n        example_image = self.dataset[self.dataset.keys().__iter__().__next__()]['images'][0]\n        return determine_reader_writer_from_dataset_json(self.dataset_json, example_image)\n\n    @staticmethod\n    def static_estimate_VRAM_usage(patch_size: Tuple[int],\n                                   input_channels: int,\n                                   output_channels: int,\n                                   arch_class_name: str,\n                                   arch_kwargs: dict,\n                                   arch_kwargs_req_import: Tuple[str, ...]):\n        \"\"\"\n        Works for PlainConvUNet, ResidualEncoderUNet\n        \"\"\"\n        a 
= torch.get_num_threads()\n        torch.set_num_threads(get_allowed_n_proc_DA())\n        # print(f'instantiating network, patch size {patch_size}, pool op: {arch_kwargs[\"strides\"]}')\n        net = get_network_from_plans(arch_class_name, arch_kwargs, arch_kwargs_req_import, input_channels,\n                                     output_channels,\n                                     allow_init=False)\n        ret = net.compute_conv_feature_map_size(patch_size)\n        torch.set_num_threads(a)\n        return ret\n\n    def determine_resampling(self, *args, **kwargs):\n        \"\"\"\n        returns what functions to use for resampling data and seg, respectively. Also returns kwargs\n        resampling function must be callable(data, current_spacing, new_spacing, **kwargs)\n\n        determine_resampling is called within get_plans_for_configuration to allow for different functions for each\n        configuration\n        \"\"\"\n        resampling_data = resample_data_or_seg_to_shape\n        resampling_data_kwargs = {\n            \"is_seg\": False,\n            \"order\": 3,\n            \"order_z\": 0,\n            \"force_separate_z\": None,\n        }\n        resampling_seg = resample_data_or_seg_to_shape\n        resampling_seg_kwargs = {\n            \"is_seg\": True,\n            \"order\": 1,\n            \"order_z\": 0,\n            \"force_separate_z\": None,\n        }\n        return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs\n\n    def determine_segmentation_softmax_export_fn(self, *args, **kwargs):\n        \"\"\"\n        function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be\n        used as target. 
current_spacing and new_spacing are merely there in case we want to use it somehow\n\n        determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different\n        functions for each configuration\n\n        \"\"\"\n        resampling_fn = resample_data_or_seg_to_shape\n        resampling_fn_kwargs = {\n            \"is_seg\": False,\n            \"order\": 1,\n            \"order_z\": 0,\n            \"force_separate_z\": None,\n        }\n        return resampling_fn, resampling_fn_kwargs\n\n    def determine_fullres_target_spacing(self) -> np.ndarray:\n        \"\"\"\n        per default we use the 50th percentile=median for the target spacing. Higher spacing results in smaller data\n        and thus faster and easier training. Smaller spacing results in larger data and thus longer and harder training\n\n        For some datasets the median is not a good choice. Those are the datasets where the spacing is very anisotropic\n        (for example ACDC with (10, 1.5, 1.5)). These datasets still have examples with a spacing of 5 or 6 mm in the low\n        resolution axis. Choosing the median here will result in bad interpolation artifacts that can substantially\n        impact performance (due to the low number of slices).\n        \"\"\"\n        if self.overwrite_target_spacing is not None:\n            return np.array(self.overwrite_target_spacing)\n\n        spacings = np.vstack(self.dataset_fingerprint['spacings'])\n        sizes = self.dataset_fingerprint['shapes_after_crop']\n\n        target = np.percentile(spacings, 50, 0)\n\n        # todo sizes_after_resampling = [compute_new_shape(j, i, target) for i, j in zip(spacings, sizes)]\n\n        target_size = np.percentile(np.vstack(sizes), 50, 0)\n        # we need to identify datasets for which a different target spacing could be beneficial. 
These datasets have\n        # the following properties:\n        # - one axis with much lower resolution than the others\n        # - the lowres axis has far fewer voxels than the others\n        # - (the size in mm of the lowres axis is also reduced)\n        worst_spacing_axis = np.argmax(target)\n        other_axes = [i for i in range(len(target)) if i != worst_spacing_axis]\n        other_spacings = [target[i] for i in other_axes]\n        other_sizes = [target_size[i] for i in other_axes]\n\n        has_aniso_spacing = target[worst_spacing_axis] > (self.anisotropy_threshold * max(other_spacings))\n        has_aniso_voxels = target_size[worst_spacing_axis] * self.anisotropy_threshold < min(other_sizes)\n\n        if has_aniso_spacing and has_aniso_voxels:\n            spacings_of_that_axis = spacings[:, worst_spacing_axis]\n            target_spacing_of_that_axis = np.percentile(spacings_of_that_axis, 10)\n            # don't let the spacing of that axis get lower than the other axes\n            if target_spacing_of_that_axis < max(other_spacings):\n                target_spacing_of_that_axis = max(max(other_spacings), target_spacing_of_that_axis) + 1e-5\n            target[worst_spacing_axis] = target_spacing_of_that_axis\n        return target\n\n    def determine_normalization_scheme_and_whether_mask_is_used_for_norm(self) -> Tuple[List[str], List[bool]]:\n        if 'channel_names' not in self.dataset_json.keys():\n            print('WARNING: \"modality\" should be renamed to \"channel_names\" in dataset.json. 
This will be '\n                  'enforced soon!')\n        modalities = self.dataset_json['channel_names'] if 'channel_names' in self.dataset_json.keys() else \\\n            self.dataset_json['modality']\n        normalization_schemes = [get_normalization_scheme(m) for m in modalities.values()]\n        if self.dataset_fingerprint['median_relative_size_after_cropping'] < (3 / 4.):\n            use_nonzero_mask_for_norm = [i.leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true for i in\n                                         normalization_schemes]\n        else:\n            use_nonzero_mask_for_norm = [False] * len(normalization_schemes)\n            assert all([i in (True, False) for i in use_nonzero_mask_for_norm]), 'use_nonzero_mask_for_norm must be ' \\\n                                                                                 'True or False and cannot be None'\n        normalization_schemes = [i.__name__ for i in normalization_schemes]\n        return normalization_schemes, use_nonzero_mask_for_norm\n\n    def determine_transpose(self):\n        if self.suppress_transpose:\n            return [0, 1, 2], [0, 1, 2]\n\n        # todo we should use shapes for that as well. 
Not quite sure how yet\n        target_spacing = self.determine_fullres_target_spacing()\n\n        max_spacing_axis = np.argmax(target_spacing)\n        remaining_axes = [i for i in list(range(3)) if i != max_spacing_axis]\n        transpose_forward = [max_spacing_axis] + remaining_axes\n        transpose_backward = [np.argwhere(np.array(transpose_forward) == i)[0][0] for i in range(3)]\n        return transpose_forward, transpose_backward\n\n    def get_plans_for_configuration(self,\n                                    spacing: Union[np.ndarray, Tuple[float, ...], List[float]],\n                                    median_shape: Union[np.ndarray, Tuple[int, ...]],\n                                    data_identifier: str,\n                                    approximate_n_voxels_dataset: float,\n                                    _cache: dict) -> dict:\n        def _features_per_stage(num_stages, max_num_features) -> Tuple[int, ...]:\n            return tuple([min(max_num_features, self.UNet_base_num_features * 2 ** i) for\n                          i in range(num_stages)])\n\n        def _keygen(patch_size, strides):\n            return str(patch_size) + '_' + str(strides)\n\n        assert all([i > 0 for i in spacing]), f\"Spacing must be > 0! 
Spacing: {spacing}\"\n        num_input_channels = len(self.dataset_json['channel_names'].keys()\n                                 if 'channel_names' in self.dataset_json.keys()\n                                 else self.dataset_json['modality'].keys())\n        max_num_features = self.UNet_max_features_2d if len(spacing) == 2 else self.UNet_max_features_3d\n        unet_conv_op = convert_dim_to_conv_op(len(spacing))\n\n        # print(spacing, median_shape, approximate_n_voxels_dataset)\n        # find an initial patch size\n        # we first use the spacing to get an aspect ratio\n        tmp = 1 / np.array(spacing)\n\n        # we then upscale it so that it initially is certainly larger than what we need (rescale to have the same\n        # volume as a patch of size 256 ** 3)\n        # this may need to be adapted when using absurdly large GPU memory targets. Increasing this now would not be\n        # ideal because large initial patch sizes increase computation time because more iterations in the while loop\n        # further down may be required.\n        if len(spacing) == 3:\n            initial_patch_size = [round(i) for i in tmp * (256 ** 3 / np.prod(tmp)) ** (1 / 3)]\n        elif len(spacing) == 2:\n            initial_patch_size = [round(i) for i in tmp * (2048 ** 2 / np.prod(tmp)) ** (1 / 2)]\n        else:\n            raise RuntimeError()\n\n        # clip initial patch size to median_shape. It makes little sense to have it be larger than that. Note that\n        # this is different from how nnU-Net v1 does it!\n        # todo patch size can still get too large because we pad the patch size to a multiple of 2**n\n        initial_patch_size = np.array([min(i, j) for i, j in zip(initial_patch_size, median_shape[:len(spacing)])])\n\n        # use that to get the network topology. 
Note that this changes the patch_size depending on the number of\n        # pooling operations (must be divisible by 2**num_pool in each axis)\n        network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \\\n        shape_must_be_divisible_by = get_pool_and_conv_props(spacing, initial_patch_size,\n                                                             self.UNet_featuremap_min_edge_length,\n                                                             999999)\n        num_stages = len(pool_op_kernel_sizes)\n\n        norm = get_matching_instancenorm(unet_conv_op)\n        architecture_kwargs = {\n            'network_class_name': self.UNet_class.__module__ + '.' + self.UNet_class.__name__,\n            'arch_kwargs': {\n                'n_stages': num_stages,\n                'features_per_stage': _features_per_stage(num_stages, max_num_features),\n                'conv_op': unet_conv_op.__module__ + '.' + unet_conv_op.__name__,\n                'kernel_sizes': conv_kernel_sizes,\n                'strides': pool_op_kernel_sizes,\n                'n_conv_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages],\n                'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1],\n                'conv_bias': True,\n                'norm_op': norm.__module__ + '.' 
+ norm.__name__,\n                'norm_op_kwargs': {'eps': 1e-5, 'affine': True},\n                'dropout_op': None,\n                'dropout_op_kwargs': None,\n                'nonlin': 'torch.nn.LeakyReLU',\n                'nonlin_kwargs': {'inplace': True},\n            },\n            '_kw_requires_import': ('conv_op', 'norm_op', 'dropout_op', 'nonlin'),\n        }\n\n        # now estimate vram consumption\n        if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys():\n            estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)]\n        else:\n            estimate = self.static_estimate_VRAM_usage(patch_size,\n                                                       num_input_channels,\n                                                       len(self.dataset_json['labels'].keys()),\n                                                       architecture_kwargs['network_class_name'],\n                                                       architecture_kwargs['arch_kwargs'],\n                                                       architecture_kwargs['_kw_requires_import'],\n                                                       )\n            _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate\n\n        # how large is the reference for us here (batch size etc)?\n        # adapt for our vram target\n        reference = (self.UNet_reference_val_2d if len(spacing) == 2 else self.UNet_reference_val_3d) * \\\n                    (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB)\n\n        ref_bs = self.UNet_reference_val_corresp_bs_2d if len(spacing) == 2 else self.UNet_reference_val_corresp_bs_3d\n        # we enforce a batch size of at least two, reference values may have been computed for different batch sizes.\n        # Correct for that in the while loop if statement\n        while (estimate / ref_bs * 2) > reference:\n            # print(patch_size, estimate, reference)\n            # patch size seems to be too large, 
so we need to reduce it. Reduce the axis that currently violates the\n            # aspect ratio the most (that is the largest relative to median shape)\n            axis_to_be_reduced = np.argsort([i / j for i, j in zip(patch_size, median_shape[:len(spacing)])])[-1]\n\n            # we cannot simply reduce that axis by shape_must_be_divisible_by[axis_to_be_reduced] because this\n            # may cause us to skip some valid sizes, for example shape_must_be_divisible_by is 64 for a shape of 256.\n            # If we subtracted that we would end up with 192, skipping 224 which is also a valid patch size\n            # (224 / 2**5 = 7; 7 < 2 * self.UNet_featuremap_min_edge_length(4) so it's valid). So we need to first\n            # subtract shape_must_be_divisible_by, then recompute it and then subtract the\n            # recomputed shape_must_be_divisible_by. Annoying.\n            patch_size = list(patch_size)\n            tmp = deepcopy(patch_size)\n            tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]\n            _, _, _, _, shape_must_be_divisible_by = \\\n                get_pool_and_conv_props(spacing, tmp,\n                                        self.UNet_featuremap_min_edge_length,\n                                        999999)\n            patch_size[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]\n\n            # now recompute topology\n            network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \\\n            shape_must_be_divisible_by = get_pool_and_conv_props(spacing, patch_size,\n                                                                 self.UNet_featuremap_min_edge_length,\n                                                                 999999)\n\n            num_stages = len(pool_op_kernel_sizes)\n            architecture_kwargs['arch_kwargs'].update({\n                'n_stages': num_stages,\n                'kernel_sizes': conv_kernel_sizes,\n        
        'strides': pool_op_kernel_sizes,\n                'features_per_stage': _features_per_stage(num_stages, max_num_features),\n                'n_conv_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages],\n                'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1],\n            })\n            if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys():\n                estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)]\n            else:\n                estimate = self.static_estimate_VRAM_usage(\n                    patch_size,\n                    num_input_channels,\n                    len(self.dataset_json['labels'].keys()),\n                    architecture_kwargs['network_class_name'],\n                    architecture_kwargs['arch_kwargs'],\n                    architecture_kwargs['_kw_requires_import'],\n                )\n                _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate\n\n        # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was\n        # executed. If not, additional vram headroom is used to increase batch size\n        batch_size = round((reference / estimate) * ref_bs)\n\n        # we need to cap the batch size to cover at most 5% of the entire dataset. Overfitting precaution. 
We cannot\n        # go smaller than self.UNet_min_batch_size though\n        bs_corresponding_to_5_percent = round(\n            approximate_n_voxels_dataset * self.max_dataset_covered / np.prod(patch_size, dtype=np.float64))\n        batch_size = max(min(batch_size, bs_corresponding_to_5_percent), self.UNet_min_batch_size)\n\n        resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs = self.determine_resampling()\n        resampling_softmax, resampling_softmax_kwargs = self.determine_segmentation_softmax_export_fn()\n\n        normalization_schemes, mask_is_used_for_norm = \\\n            self.determine_normalization_scheme_and_whether_mask_is_used_for_norm()\n\n        plan = {\n            'data_identifier': data_identifier,\n            'preprocessor_name': self.preprocessor_name,\n            'batch_size': batch_size,\n            'patch_size': patch_size,\n            'median_image_size_in_voxels': median_shape,\n            'spacing': spacing,\n            'normalization_schemes': normalization_schemes,\n            'use_mask_for_norm': mask_is_used_for_norm,\n            'resampling_fn_data': resampling_data.__name__,\n            'resampling_fn_seg': resampling_seg.__name__,\n            'resampling_fn_data_kwargs': resampling_data_kwargs,\n            'resampling_fn_seg_kwargs': resampling_seg_kwargs,\n            'resampling_fn_probabilities': resampling_softmax.__name__,\n            'resampling_fn_probabilities_kwargs': resampling_softmax_kwargs,\n            'architecture': architecture_kwargs\n        }\n        return plan\n\n    def plan_experiment(self):\n        \"\"\"\n        MOVE EVERYTHING INTO THE PLANS. MAXIMUM FLEXIBILITY\n\n        Ideally I would like to move transpose_forward/backward into the configurations so that this can also be done\n        differently for each configuration but this would cause problems with identifying the correct axes for 2d. There\n        surely is a way around that but eh. 
I'm feeling lazy and featuritis must also not be pushed to the extremes.\n\n        So for now if you want a different transpose_forward/backward you need to create a new planner. Also not too\n        hard.\n        \"\"\"\n        # we use this as a cache to prevent having to instantiate the architecture too often. Saves computation time\n        _tmp = {}\n\n        # first get transpose\n        transpose_forward, transpose_backward = self.determine_transpose()\n\n        # get fullres spacing and transpose it\n        fullres_spacing = self.determine_fullres_target_spacing()\n        fullres_spacing_transposed = fullres_spacing[transpose_forward]\n\n        # get transposed new median shape (what we would have after resampling)\n        new_shapes = [compute_new_shape(j, i, fullres_spacing) for i, j in\n                      zip(self.dataset_fingerprint['spacings'], self.dataset_fingerprint['shapes_after_crop'])]\n        new_median_shape = np.median(new_shapes, 0)\n        new_median_shape_transposed = new_median_shape[transpose_forward]\n\n        approximate_n_voxels_dataset = float(np.prod(new_median_shape_transposed, dtype=np.float64) *\n                                             self.dataset_json['numTraining'])\n        # only run 3d if this is a 3d dataset\n        if new_median_shape_transposed[0] != 1:\n            plan_3d_fullres = self.get_plans_for_configuration(fullres_spacing_transposed,\n                                                               new_median_shape_transposed,\n                                                               self.generate_data_identifier('3d_fullres'),\n                                                               approximate_n_voxels_dataset, _tmp)\n            # maybe add 3d_lowres as well\n            patch_size_fullres = plan_3d_fullres['patch_size']\n            median_num_voxels = np.prod(new_median_shape_transposed, dtype=np.float64)\n            num_voxels_in_patch = np.prod(patch_size_fullres, 
dtype=np.float64)\n\n            plan_3d_lowres = None\n            lowres_spacing = deepcopy(plan_3d_fullres['spacing'])\n\n            spacing_increase_factor = 1.03  # used to be 1.01 but that is slow with new GPU memory estimation!\n            while num_voxels_in_patch / median_num_voxels < self.lowres_creation_threshold:\n                # we incrementally increase the target spacing. We start with the anisotropic axis/axes until it/they\n                # is/are similar (factor 2) to the other ax(i/e)s.\n                max_spacing = max(lowres_spacing)\n                if np.any((max_spacing / lowres_spacing) > 2):\n                    lowres_spacing[(max_spacing / lowres_spacing) > 2] *= spacing_increase_factor\n                else:\n                    lowres_spacing *= spacing_increase_factor\n                median_num_voxels = np.prod(plan_3d_fullres['spacing'] / lowres_spacing * new_median_shape_transposed,\n                                            dtype=np.float64)\n                # print(lowres_spacing)\n                plan_3d_lowres = self.get_plans_for_configuration(lowres_spacing,\n                                                                  tuple([round(i) for i in plan_3d_fullres['spacing'] /\n                                                                         lowres_spacing * new_median_shape_transposed]),\n                                                                  self.generate_data_identifier('3d_lowres'),\n                                                                  float(np.prod(median_num_voxels) *\n                                                                        self.dataset_json['numTraining']), _tmp)\n                num_voxels_in_patch = np.prod(plan_3d_lowres['patch_size'], dtype=np.int64)\n                print(f'Attempting to find 3d_lowres config. '\n                      f'\\nCurrent spacing: {lowres_spacing}. '\n                      f'\\nCurrent patch size: {plan_3d_lowres[\"patch_size\"]}. 
'\n                      f'\\nCurrent median shape: {plan_3d_fullres[\"spacing\"] / lowres_spacing * new_median_shape_transposed}')\n            if np.prod(new_median_shape_transposed, dtype=np.float64) / median_num_voxels < 2:\n                print(f'Dropping 3d_lowres config because the image size difference to 3d_fullres is too small. '\n                      f'3d_fullres: {new_median_shape_transposed}, '\n                      f'3d_lowres: {[round(i) for i in plan_3d_fullres[\"spacing\"] / lowres_spacing * new_median_shape_transposed]}')\n                plan_3d_lowres = None\n            if plan_3d_lowres is not None:\n                plan_3d_lowres['batch_dice'] = False\n                plan_3d_fullres['batch_dice'] = True\n            else:\n                plan_3d_fullres['batch_dice'] = False\n        else:\n            plan_3d_fullres = None\n            plan_3d_lowres = None\n\n        # 2D configuration\n        plan_2d = self.get_plans_for_configuration(fullres_spacing_transposed[1:],\n                                                   new_median_shape_transposed[1:],\n                                                   self.generate_data_identifier('2d'), approximate_n_voxels_dataset,\n                                                   _tmp)\n        plan_2d['batch_dice'] = True\n\n        print('2D U-Net configuration:')\n        print(plan_2d)\n        print()\n\n        # median spacing and shape, just for reference when printing the plans\n        median_spacing = np.median(self.dataset_fingerprint['spacings'], 0)[transpose_forward]\n        median_shape = np.median(self.dataset_fingerprint['shapes_after_crop'], 0)[transpose_forward]\n\n        # instead of writing all that into the plans we just copy the original file. More files, but less crowded\n        # per file.\n        shutil.copy(join(self.raw_dataset_folder, 'dataset.json'),\n                    join(nnUNet_preprocessed, self.dataset_name, 'dataset.json'))\n\n        # json is ###. 
I hate it... \"Object of type int64 is not JSON serializable\"\n        plans = {\n            'dataset_name': self.dataset_name,\n            'plans_name': self.plans_identifier,\n            'original_median_spacing_after_transp': [float(i) for i in median_spacing],\n            'original_median_shape_after_transp': [int(round(i)) for i in median_shape],\n            'image_reader_writer': self.determine_reader_writer().__name__,\n            'transpose_forward': [int(i) for i in transpose_forward],\n            'transpose_backward': [int(i) for i in transpose_backward],\n            'configurations': {'2d': plan_2d},\n            'experiment_planner_used': self.__class__.__name__,\n            'label_manager': 'LabelManager',\n            'foreground_intensity_properties_per_channel': self.dataset_fingerprint[\n                'foreground_intensity_properties_per_channel']\n        }\n\n        if plan_3d_lowres is not None:\n            plans['configurations']['3d_lowres'] = plan_3d_lowres\n            if plan_3d_fullres is not None:\n                plans['configurations']['3d_lowres']['next_stage'] = '3d_cascade_fullres'\n            print('3D lowres U-Net configuration:')\n            print(plan_3d_lowres)\n            print()\n        if plan_3d_fullres is not None:\n            plans['configurations']['3d_fullres'] = plan_3d_fullres\n            print('3D fullres U-Net configuration:')\n            print(plan_3d_fullres)\n            print()\n            if plan_3d_lowres is not None:\n                plans['configurations']['3d_cascade_fullres'] = {\n                    'inherits_from': '3d_fullres',\n                    'previous_stage': '3d_lowres'\n                }\n\n        self.plans = plans\n        self.save_plans(plans)\n        return plans\n\n    def save_plans(self, plans):\n        recursive_fix_for_json_export(plans)\n\n        plans_file = join(nnUNet_preprocessed, self.dataset_name, self.plans_identifier + '.json')\n\n        # we don't 
want to overwrite potentially existing custom configurations every time this is executed. So let's\n        # read the plans file if it already exists and keep any non-default configurations\n        if isfile(plans_file):\n            old_plans = load_json(plans_file)\n            old_configurations = old_plans['configurations']\n            for c in plans['configurations'].keys():\n                if c in old_configurations.keys():\n                    del (old_configurations[c])\n            plans['configurations'].update(old_configurations)\n\n        maybe_mkdir_p(join(nnUNet_preprocessed, self.dataset_name))\n        save_json(plans, plans_file, sort_keys=False)\n        print(f\"Plans were saved to {join(nnUNet_preprocessed, self.dataset_name, self.plans_identifier + '.json')}\")\n\n    def generate_data_identifier(self, configuration_name: str) -> str:\n        \"\"\"\n        configurations are unique within each plans file but different plans file can have configurations with the\n        same name. In order to distinguish the associated data we need a data identifier that reflects not just the\n        config but also the plans it originates from\n        \"\"\"\n        return self.plans_identifier + '_' + configuration_name\n\n    def load_plans(self, fname: str):\n        self.plans = load_json(fname)\n\n\ndef _maybe_copy_splits_file(splits_file: str, target_fname: str):\n    if not isfile(target_fname):\n        shutil.copy(splits_file, target_fname)\n    else:\n        # split already exists, do not copy, but check that the splits match.\n        # This code allows target_fname to contain more splits than splits_file. 
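This is OK.\n        splits_source = load_json(splits_file)\n        splits_target = load_json(target_fname)\n        # the target may contain additional folds, but it must not have fewer folds than the source\n        assert len(splits_target) >= len(splits_source), f'{target_fname} contains fewer folds than {splits_file}'\n        # all folds in the source file must match the target file\n        for i in range(len(splits_source)):\n            train_source = set(splits_source[i]['train'])\n            train_target = set(splits_target[i]['train'])\n            assert train_target == train_source, f'train cases of fold {i} do not match the existing splits file'\n            val_source = set(splits_source[i]['val'])\n            val_target = set(splits_target[i]['val'])\n            assert val_source == val_target, f'val cases of fold {i} do not match the existing splits file'\n\n\nif __name__ == '__main__':\n    ExperimentPlanner(2, 8).plan_experiment()\n"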
  },
  {
    "path": "nnunetv2/experiment_planning/experiment_planners/network_topology.py",
    "content": "from copy import deepcopy\nimport numpy as np\n\n\ndef get_shape_must_be_divisible_by(net_numpool_per_axis):\n    return 2 ** np.array(net_numpool_per_axis)\n\n\ndef pad_shape(shape, must_be_divisible_by):\n    \"\"\"\n    pads shape so that it is divisible by must_be_divisible_by\n    :param shape:\n    :param must_be_divisible_by:\n    :return:\n    \"\"\"\n    if not isinstance(must_be_divisible_by, (tuple, list, np.ndarray)):\n        must_be_divisible_by = [must_be_divisible_by] * len(shape)\n    else:\n        assert len(must_be_divisible_by) == len(shape)\n\n    new_shp = [shape[i] + must_be_divisible_by[i] - shape[i] % must_be_divisible_by[i] for i in range(len(shape))]\n\n    for i in range(len(shape)):\n        if shape[i] % must_be_divisible_by[i] == 0:\n            new_shp[i] -= must_be_divisible_by[i]\n    new_shp = np.array(new_shp).astype(int)\n    return new_shp\n\n\ndef get_pool_and_conv_props(spacing, patch_size, min_feature_map_size, max_numpool):\n    \"\"\"\n    this is the same as get_pool_and_conv_props_v2 from old nnunet\n\n    :param spacing:\n    :param patch_size:\n    :param min_feature_map_size: min edge length of feature maps in bottleneck\n    :param max_numpool:\n    :return:\n    \"\"\"\n    # todo review this code\n    dim = len(spacing)\n\n    current_spacing = deepcopy(list(spacing))\n    current_size = deepcopy(list(patch_size))\n\n    pool_op_kernel_sizes = [[1] * len(spacing)]\n    conv_kernel_sizes = []\n\n    num_pool_per_axis = [0] * dim\n    kernel_size = [1] * dim\n\n    while True:\n        # exclude axes that we cannot pool further because of min_feature_map_size constraint\n        valid_axes_for_pool = [i for i in range(dim) if current_size[i] >= 2*min_feature_map_size]\n        if len(valid_axes_for_pool) < 1:\n            break\n\n        spacings_of_axes = [current_spacing[i] for i in valid_axes_for_pool]\n\n        # find axis that are within factor of 2 within smallest spacing\n        
min_spacing_of_valid = min(spacings_of_axes)\n        valid_axes_for_pool = [i for i in valid_axes_for_pool if current_spacing[i] / min_spacing_of_valid < 2]\n\n        # max_numpool constraint\n        valid_axes_for_pool = [i for i in valid_axes_for_pool if num_pool_per_axis[i] < max_numpool]\n\n        if len(valid_axes_for_pool) == 1:\n            if current_size[valid_axes_for_pool[0]] >= 3 * min_feature_map_size:\n                pass\n            else:\n                break\n        if len(valid_axes_for_pool) < 1:\n            break\n\n        # now we need to find kernel sizes\n        # kernel sizes are initialized to 1. They are successively set to 3 when their associated axis becomes within\n        # factor 2 of min_spacing. Once they are 3 they remain 3\n        for d in range(dim):\n            if kernel_size[d] == 3:\n                continue\n            else:\n                if current_spacing[d] / min(current_spacing) < 2:\n                    kernel_size[d] = 3\n\n        other_axes = [i for i in range(dim) if i not in valid_axes_for_pool]\n\n        pool_kernel_sizes = [0] * dim\n        for v in valid_axes_for_pool:\n            pool_kernel_sizes[v] = 2\n            num_pool_per_axis[v] += 1\n            current_spacing[v] *= 2\n            current_size[v] = np.ceil(current_size[v] / 2)\n        for nv in other_axes:\n            pool_kernel_sizes[nv] = 1\n\n        pool_op_kernel_sizes.append(pool_kernel_sizes)\n        conv_kernel_sizes.append(deepcopy(kernel_size))\n        #print(conv_kernel_sizes)\n\n    must_be_divisible_by = get_shape_must_be_divisible_by(num_pool_per_axis)\n    patch_size = pad_shape(patch_size, must_be_divisible_by)\n\n    def _to_tuple(lst):\n        return tuple(_to_tuple(i) if isinstance(i, list) else i for i in lst)\n\n    # we need to add one more conv_kernel_size for the bottleneck. 
We always use 3x3(x3) conv here\n    conv_kernel_sizes.append([3]*dim)\n    return num_pool_per_axis, _to_tuple(pool_op_kernel_sizes), _to_tuple(conv_kernel_sizes), tuple(patch_size), must_be_divisible_by\n"
  },
  {
    "path": "nnunetv2/experiment_planning/experiment_planners/resampling/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/experiment_planning/experiment_planners/resampling/planners_no_resampling.py",
    "content": "from typing import Union, List, Tuple\n\nfrom nnunetv2.experiment_planning.experiment_planners.residual_unets.residual_encoder_unet_planners import \\\n    nnUNetPlannerResEncL\nfrom nnunetv2.preprocessing.resampling.no_resampling import no_resampling_hack\n\n\nclass nnUNetPlannerResEncL_noResampling(nnUNetPlannerResEncL):\n    \"\"\"\n    This planner will generate 3d_lowres as well. Don't trust it. Everything will remain in the original shape.\n    No resampling will ever be done.\n    \"\"\"\n    def __init__(self, dataset_name_or_id: Union[str, int],\n                 gpu_memory_target_in_gb: float = 24,\n                 preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLPlans_noResampling',\n                 overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,\n                 suppress_transpose: bool = False):\n        super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,\n                         overwrite_target_spacing, suppress_transpose)\n\n    def generate_data_identifier(self, configuration_name: str) -> str:\n        \"\"\"\n        configurations are unique within each plans file but different plans file can have configurations with the\n        same name. In order to distinguish the associated data we need a data identifier that reflects not just the\n        config but also the plans it originates from\n        \"\"\"\n        return self.plans_identifier + '_' + configuration_name\n\n    def determine_resampling(self, *args, **kwargs):\n        \"\"\"\n        returns what functions to use for resampling data and seg, respectively. 
Also returns kwargs\n        resampling function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs)\n\n        determine_resampling is called within get_plans_for_configuration to allow for different functions for each\n        configuration\n        \"\"\"\n        resampling_data = no_resampling_hack\n        resampling_data_kwargs = {}\n        resampling_seg = no_resampling_hack\n        resampling_seg_kwargs = {}\n        return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs\n\n    def determine_segmentation_softmax_export_fn(self, *args, **kwargs):\n        \"\"\"\n        function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be\n        used as target. current_spacing and new_spacing are merely there in case we want to use it somehow\n\n        determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different\n        functions for each configuration\n\n        \"\"\"\n        resampling_fn = no_resampling_hack\n        resampling_fn_kwargs = {}\n        return resampling_fn, resampling_fn_kwargs\n"
  },
  {
    "path": "nnunetv2/experiment_planning/experiment_planners/resampling/resample_with_torch.py",
    "content": "from typing import Union, List, Tuple\n\nfrom nnunetv2.configuration import ANISO_THRESHOLD\nfrom nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner\nfrom nnunetv2.experiment_planning.experiment_planners.residual_unets.residual_encoder_unet_planners import \\\n    nnUNetPlannerResEncL\nfrom nnunetv2.preprocessing.resampling.resample_torch import resample_torch_fornnunet\n\n\nclass nnUNetPlannerResEncL_torchres(nnUNetPlannerResEncL):\n    def __init__(self, dataset_name_or_id: Union[str, int],\n                 gpu_memory_target_in_gb: float = 24,\n                 preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLPlans_torchres',\n                 overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,\n                 suppress_transpose: bool = False):\n        super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,\n                         overwrite_target_spacing, suppress_transpose)\n\n    def generate_data_identifier(self, configuration_name: str) -> str:\n        \"\"\"\n        configurations are unique within each plans file but different plans file can have configurations with the\n        same name. In order to distinguish the associated data we need a data identifier that reflects not just the\n        config but also the plans it originates from\n        \"\"\"\n        return self.plans_identifier + '_' + configuration_name\n\n    def determine_resampling(self, *args, **kwargs):\n        \"\"\"\n        returns what functions to use for resampling data and seg, respectively. 
Also returns kwargs\n        resampling function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs)\n\n        determine_resampling is called within get_plans_for_configuration to allow for different functions for each\n        configuration\n        \"\"\"\n        resampling_data = resample_torch_fornnunet\n        resampling_data_kwargs = {\n            \"is_seg\": False,\n            'force_separate_z': False,\n            'memefficient_seg_resampling': False\n        }\n        resampling_seg = resample_torch_fornnunet\n        resampling_seg_kwargs = {\n            \"is_seg\": True,\n            'force_separate_z': False,\n            'memefficient_seg_resampling': False\n        }\n        return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs\n\n    def determine_segmentation_softmax_export_fn(self, *args, **kwargs):\n        \"\"\"\n        function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be\n        used as target. 
current_spacing and new_spacing are merely there in case we want to use it somehow\n\n        determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different\n        functions for each configuration\n\n        \"\"\"\n        resampling_fn = resample_torch_fornnunet\n        resampling_fn_kwargs = {\n            \"is_seg\": False,\n            'force_separate_z': False,\n            'memefficient_seg_resampling': False\n        }\n        return resampling_fn, resampling_fn_kwargs\n\n\nclass nnUNetPlannerResEncL_torchres_sepz(nnUNetPlannerResEncL):\n    def __init__(self, dataset_name_or_id: Union[str, int],\n                 gpu_memory_target_in_gb: float = 24,\n                 preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLPlans_torchres_sepz',\n                 overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,\n                 suppress_transpose: bool = False):\n        super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,\n                         overwrite_target_spacing, suppress_transpose)\n\n    def generate_data_identifier(self, configuration_name: str) -> str:\n        \"\"\"\n        configurations are unique within each plans file but different plans file can have configurations with the\n        same name. In order to distinguish the associated data we need a data identifier that reflects not just the\n        config but also the plans it originates from\n        \"\"\"\n        return self.plans_identifier + '_' + configuration_name\n\n    def determine_resampling(self, *args, **kwargs):\n        \"\"\"\n        returns what functions to use for resampling data and seg, respectively. 
Also returns kwargs\n        resampling function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs)\n\n        determine_resampling is called within get_plans_for_configuration to allow for different functions for each\n        configuration\n        \"\"\"\n        resampling_data = resample_torch_fornnunet\n        resampling_data_kwargs = {\n            \"is_seg\": False,\n            'force_separate_z': None,\n            'memefficient_seg_resampling': False,\n            'separate_z_anisotropy_threshold': ANISO_THRESHOLD\n        }\n        resampling_seg = resample_torch_fornnunet\n        resampling_seg_kwargs = {\n            \"is_seg\": True,\n            'force_separate_z': None,\n            'memefficient_seg_resampling': False,\n            'separate_z_anisotropy_threshold': ANISO_THRESHOLD\n        }\n        return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs\n\n    def determine_segmentation_softmax_export_fn(self, *args, **kwargs):\n        \"\"\"\n        function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be\n        used as target. 
current_spacing and new_spacing are merely there in case we want to use it somehow\n\n        determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different\n        functions for each configuration\n\n        \"\"\"\n        resampling_fn = resample_torch_fornnunet\n        resampling_fn_kwargs = {\n            \"is_seg\": False,\n            'force_separate_z': None,\n            'memefficient_seg_resampling': False,\n            'separate_z_anisotropy_threshold': ANISO_THRESHOLD\n        }\n        return resampling_fn, resampling_fn_kwargs\n\n\nclass nnUNetPlanner_torchres(ExperimentPlanner):\n    def __init__(self, dataset_name_or_id: Union[str, int],\n                 gpu_memory_target_in_gb: float = 8,\n                 preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetPlans_torchres',\n                 overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,\n                 suppress_transpose: bool = False):\n        super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,\n                         overwrite_target_spacing, suppress_transpose)\n\n    def generate_data_identifier(self, configuration_name: str) -> str:\n        \"\"\"\n        configurations are unique within each plans file but different plans file can have configurations with the\n        same name. In order to distinguish the associated data we need a data identifier that reflects not just the\n        config but also the plans it originates from\n        \"\"\"\n        return self.plans_identifier + '_' + configuration_name\n\n    def determine_resampling(self, *args, **kwargs):\n        \"\"\"\n        returns what functions to use for resampling data and seg, respectively. 
Also returns kwargs\n        resampling function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs)\n\n        determine_resampling is called within get_plans_for_configuration to allow for different functions for each\n        configuration\n        \"\"\"\n        resampling_data = resample_torch_fornnunet\n        resampling_data_kwargs = {\n            \"is_seg\": False,\n            'force_separate_z': False,\n            'memefficient_seg_resampling': False\n        }\n        resampling_seg = resample_torch_fornnunet\n        resampling_seg_kwargs = {\n            \"is_seg\": True,\n            'force_separate_z': False,\n            'memefficient_seg_resampling': False\n        }\n        return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs\n\n    def determine_segmentation_softmax_export_fn(self, *args, **kwargs):\n        \"\"\"\n        function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be\n        used as target. current_spacing and new_spacing are merely there in case we want to use it somehow\n\n        determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different\n        functions for each configuration\n\n        \"\"\"\n        resampling_fn = resample_torch_fornnunet\n        resampling_fn_kwargs = {\n            \"is_seg\": False,\n            'force_separate_z': False,\n            'memefficient_seg_resampling': False\n        }\n        return resampling_fn, resampling_fn_kwargs\n"
  },
  {
    "path": "nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py",
    "content": "import numpy as np\nfrom copy import deepcopy\nfrom typing import Union, List, Tuple\n\nfrom dynamic_network_architectures.architectures.unet import ResidualEncoderUNet\nfrom dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm\nfrom torch import nn\n\nfrom nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner\n\nfrom nnunetv2.experiment_planning.experiment_planners.network_topology import get_pool_and_conv_props\n\n\nclass ResEncUNetPlanner(ExperimentPlanner):\n    def __init__(self, dataset_name_or_id: Union[str, int],\n                 gpu_memory_target_in_gb: float = 8,\n                 preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetPlans',\n                 overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,\n                 suppress_transpose: bool = False):\n        super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,\n                         overwrite_target_spacing, suppress_transpose)\n        self.UNet_class = ResidualEncoderUNet\n        # the following two numbers are really arbitrary and were set to reproduce default nnU-Net's configurations as\n        # much as possible\n        self.UNet_reference_val_3d = 680000000\n        self.UNet_reference_val_2d = 135000000\n        self.UNet_blocks_per_stage_encoder = (1, 3, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6)\n        self.UNet_blocks_per_stage_decoder = (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1)\n\n    def generate_data_identifier(self, configuration_name: str) -> str:\n        \"\"\"\n        configurations are unique within each plans file but different plans file can have configurations with the\n        same name. 
In order to distinguish the associated data we need a data identifier that reflects not just the\n        config but also the plans it originates from\n        \"\"\"\n        if configuration_name == '2d' or configuration_name == '3d_fullres':\n            # we do not deviate from ExperimentPlanner so we can reuse its data\n            return 'nnUNetPlans' + '_' + configuration_name\n        else:\n            return self.plans_identifier + '_' + configuration_name\n\n    def get_plans_for_configuration(self,\n                                    spacing: Union[np.ndarray, Tuple[float, ...], List[float]],\n                                    median_shape: Union[np.ndarray, Tuple[int, ...]],\n                                    data_identifier: str,\n                                    approximate_n_voxels_dataset: float,\n                                    _cache: dict) -> dict:\n        def _features_per_stage(num_stages, max_num_features) -> Tuple[int, ...]:\n            return tuple([min(max_num_features, self.UNet_base_num_features * 2 ** i) for\n                          i in range(num_stages)])\n\n        def _keygen(patch_size, strides):\n            return str(patch_size) + '_' + str(strides)\n\n        assert all([i > 0 for i in spacing]), f\"Spacing must be > 0! 
Spacing: {spacing}\"\n        num_input_channels = len(self.dataset_json['channel_names'].keys()\n                                 if 'channel_names' in self.dataset_json.keys()\n                                 else self.dataset_json['modality'].keys())\n        max_num_features = self.UNet_max_features_2d if len(spacing) == 2 else self.UNet_max_features_3d\n        unet_conv_op = convert_dim_to_conv_op(len(spacing))\n\n        # print(spacing, median_shape, approximate_n_voxels_dataset)\n        # find an initial patch size\n        # we first use the spacing to get an aspect ratio\n        tmp = 1 / np.array(spacing)\n\n        # we then upscale it so that it initially is certainly larger than what we need (rescale to have the same\n        # volume as a patch of size 256 ** 3)\n        # this may need to be adapted when using absurdly large GPU memory targets. Increasing this now would not be\n        # ideal because large initial patch sizes increase computation time because more iterations in the while loop\n        # further down may be required.\n        if len(spacing) == 3:\n            initial_patch_size = [round(i) for i in tmp * (256 ** 3 / np.prod(tmp)) ** (1 / 3)]\n        elif len(spacing) == 2:\n            initial_patch_size = [round(i) for i in tmp * (2048 ** 2 / np.prod(tmp)) ** (1 / 2)]\n        else:\n            raise RuntimeError()\n\n        # clip initial patch size to median_shape. It makes little sense to have it be larger than that. Note that\n        # this is different from how nnU-Net v1 does it!\n        # todo patch size can still get too large because we pad the patch size to a multiple of 2**n\n        initial_patch_size = np.minimum(initial_patch_size, median_shape[:len(spacing)])\n\n        # use that to get the network topology. 
Note that this changes the patch_size depending on the number of\n        # pooling operations (must be divisible by 2**num_pool in each axis)\n        network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \\\n        shape_must_be_divisible_by = get_pool_and_conv_props(spacing, initial_patch_size,\n                                                             self.UNet_featuremap_min_edge_length,\n                                                             999999)\n        num_stages = len(pool_op_kernel_sizes)\n\n        norm = get_matching_instancenorm(unet_conv_op)\n        architecture_kwargs = {\n            'network_class_name': self.UNet_class.__module__ + '.' + self.UNet_class.__name__,\n            'arch_kwargs': {\n                'n_stages': num_stages,\n                'features_per_stage': _features_per_stage(num_stages, max_num_features),\n                'conv_op': unet_conv_op.__module__ + '.' + unet_conv_op.__name__,\n                'kernel_sizes': conv_kernel_sizes,\n                'strides': pool_op_kernel_sizes,\n                'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages],\n                'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1],\n                'conv_bias': True,\n                'norm_op': norm.__module__ + '.' 
+ norm.__name__,\n                'norm_op_kwargs': {'eps': 1e-5, 'affine': True},\n                'dropout_op': None,\n                'dropout_op_kwargs': None,\n                'nonlin': 'torch.nn.LeakyReLU',\n                'nonlin_kwargs': {'inplace': True},\n            },\n            '_kw_requires_import': ('conv_op', 'norm_op', 'dropout_op', 'nonlin'),\n        }\n\n        # now estimate vram consumption\n        if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys():\n            estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)]\n        else:\n            estimate = self.static_estimate_VRAM_usage(patch_size,\n                                                       num_input_channels,\n                                                       len(self.dataset_json['labels'].keys()),\n                                                       architecture_kwargs['network_class_name'],\n                                                       architecture_kwargs['arch_kwargs'],\n                                                       architecture_kwargs['_kw_requires_import'],\n                                                       )\n            _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate\n\n        # how large is the reference for us here (batch size etc)?\n        # adapt for our vram target\n        reference = (self.UNet_reference_val_2d if len(spacing) == 2 else self.UNet_reference_val_3d) * \\\n                    (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB)\n\n        while estimate > reference:\n            # print(patch_size)\n            # patch size seems to be too large, so we need to reduce it. 
Reduce the axis that currently violates the\n            # aspect ratio the most (that is the largest relative to median shape)\n            axis_to_be_reduced = np.argsort([i / j for i, j in zip(patch_size, median_shape[:len(spacing)])])[-1]\n\n            # we cannot simply reduce that axis by shape_must_be_divisible_by[axis_to_be_reduced] because this\n            # may cause us to skip some valid sizes, for example shape_must_be_divisible_by is 64 for a shape of 256.\n            # If we subtracted that we would end up with 192, skipping 224 which is also a valid patch size\n            # (224 / 2**5 = 7; 7 < 2 * self.UNet_featuremap_min_edge_length(4) so it's valid). So we need to first\n            # subtract shape_must_be_divisible_by, then recompute it and then subtract the\n            # recomputed shape_must_be_divisible_by. Annoying.\n            patch_size = list(patch_size)\n            tmp = deepcopy(patch_size)\n            tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]\n            _, _, _, _, shape_must_be_divisible_by = \\\n                get_pool_and_conv_props(spacing, tmp,\n                                        self.UNet_featuremap_min_edge_length,\n                                        999999)\n            patch_size[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]\n\n            # now recompute topology\n            network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \\\n            shape_must_be_divisible_by = get_pool_and_conv_props(spacing, patch_size,\n                                                                 self.UNet_featuremap_min_edge_length,\n                                                                 999999)\n\n            num_stages = len(pool_op_kernel_sizes)\n            architecture_kwargs['arch_kwargs'].update({\n                'n_stages': num_stages,\n                'kernel_sizes': conv_kernel_sizes,\n                'strides': 
pool_op_kernel_sizes,\n                'features_per_stage': _features_per_stage(num_stages, max_num_features),\n                'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages],\n                'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1],\n            })\n            if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys():\n                estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)]\n            else:\n                estimate = self.static_estimate_VRAM_usage(\n                    patch_size,\n                    num_input_channels,\n                    len(self.dataset_json['labels'].keys()),\n                    architecture_kwargs['network_class_name'],\n                    architecture_kwargs['arch_kwargs'],\n                    architecture_kwargs['_kw_requires_import'],\n                )\n                _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate\n\n        # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was\n        # executed. If not, additional vram headroom is used to increase batch size\n        ref_bs = self.UNet_reference_val_corresp_bs_2d if len(spacing) == 2 else self.UNet_reference_val_corresp_bs_3d\n        batch_size = round((reference / estimate) * ref_bs)\n\n        # we need to cap the batch size to cover at most 5% of the entire dataset. Overfitting precaution. 
We cannot\n        # go smaller than self.UNet_min_batch_size though\n        bs_corresponding_to_5_percent = round(\n            approximate_n_voxels_dataset * self.max_dataset_covered / np.prod(patch_size, dtype=np.float64))\n        batch_size = max(min(batch_size, bs_corresponding_to_5_percent), self.UNet_min_batch_size)\n\n        resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs = self.determine_resampling()\n        resampling_softmax, resampling_softmax_kwargs = self.determine_segmentation_softmax_export_fn()\n\n        normalization_schemes, mask_is_used_for_norm = \\\n            self.determine_normalization_scheme_and_whether_mask_is_used_for_norm()\n\n        plan = {\n            'data_identifier': data_identifier,\n            'preprocessor_name': self.preprocessor_name,\n            'batch_size': batch_size,\n            'patch_size': patch_size,\n            'median_image_size_in_voxels': median_shape,\n            'spacing': spacing,\n            'normalization_schemes': normalization_schemes,\n            'use_mask_for_norm': mask_is_used_for_norm,\n            'resampling_fn_data': resampling_data.__name__,\n            'resampling_fn_seg': resampling_seg.__name__,\n            'resampling_fn_data_kwargs': resampling_data_kwargs,\n            'resampling_fn_seg_kwargs': resampling_seg_kwargs,\n            'resampling_fn_probabilities': resampling_softmax.__name__,\n            'resampling_fn_probabilities_kwargs': resampling_softmax_kwargs,\n            'architecture': architecture_kwargs\n        }\n        return plan\n\n\nif __name__ == '__main__':\n    # we know both of these networks run with batch size 2 and 12 on ~8-10GB, respectively\n    net = ResidualEncoderUNet(input_channels=1, n_stages=6, features_per_stage=(32, 64, 128, 256, 320, 320),\n                              conv_op=nn.Conv3d, kernel_sizes=3, strides=(1, 2, 2, 2, 2, 2),\n                              n_blocks_per_stage=(1, 3, 4, 6, 6, 6), 
num_classes=3,\n                              n_conv_per_stage_decoder=(1, 1, 1, 1, 1),\n                              conv_bias=True, norm_op=nn.InstanceNorm3d, norm_op_kwargs={}, dropout_op=None,\n                              nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True}, deep_supervision=True)\n    print(net.compute_conv_feature_map_size((128, 128, 128)))  # -> 558319104. The value you see above was finetuned\n    # from this one to match the regular nnunetplans more closely\n\n    net = ResidualEncoderUNet(input_channels=1, n_stages=7, features_per_stage=(32, 64, 128, 256, 512, 512, 512),\n                              conv_op=nn.Conv2d, kernel_sizes=3, strides=(1, 2, 2, 2, 2, 2, 2),\n                              n_blocks_per_stage=(1, 3, 4, 6, 6, 6, 6), num_classes=3,\n                              n_conv_per_stage_decoder=(1, 1, 1, 1, 1, 1),\n                              conv_bias=True, norm_op=nn.InstanceNorm2d, norm_op_kwargs={}, dropout_op=None,\n                              nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True}, deep_supervision=True)\n    print(net.compute_conv_feature_map_size((512, 512)))  # -> 129793792\n"
  },
  {
    "path": "nnunetv2/experiment_planning/experiment_planners/residual_unets/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/experiment_planning/experiment_planners/residual_unets/residual_encoder_unet_planners.py",
    "content": "import warnings\n\nimport numpy as np\nfrom copy import deepcopy\nfrom typing import Union, List, Tuple\n\nfrom dynamic_network_architectures.architectures.unet import ResidualEncoderUNet\nfrom dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm\nfrom nnunetv2.preprocessing.resampling.resample_torch import resample_torch_fornnunet\nfrom torch import nn\n\nfrom nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner\n\nfrom nnunetv2.experiment_planning.experiment_planners.network_topology import get_pool_and_conv_props\n\n\nclass ResEncUNetPlanner(ExperimentPlanner):\n    def __init__(self, dataset_name_or_id: Union[str, int],\n                 gpu_memory_target_in_gb: float = 8,\n                 preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetPlans',\n                 overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,\n                 suppress_transpose: bool = False):\n        super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,\n                         overwrite_target_spacing, suppress_transpose)\n        self.UNet_class = ResidualEncoderUNet\n        # the following two numbers are really arbitrary and were set to reproduce default nnU-Net's configurations as\n        # much as possible\n        self.UNet_reference_val_3d = 680000000\n        self.UNet_reference_val_2d = 135000000\n        self.UNet_blocks_per_stage_encoder = (1, 3, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6)\n        self.UNet_blocks_per_stage_decoder = (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1)\n\n    def generate_data_identifier(self, configuration_name: str) -> str:\n        \"\"\"\n        configurations are unique within each plans file but different plans files can have configurations with the\n        same name. 
In order to distinguish the associated data we need a data identifier that reflects not just the\n        config but also the plans it originates from\n        \"\"\"\n        if configuration_name == '2d' or configuration_name == '3d_fullres':\n            # we do not deviate from ExperimentPlanner so we can reuse its data\n            return 'nnUNetPlans' + '_' + configuration_name\n        else:\n            return self.plans_identifier + '_' + configuration_name\n\n    def get_plans_for_configuration(self,\n                                    spacing: Union[np.ndarray, Tuple[float, ...], List[float]],\n                                    median_shape: Union[np.ndarray, Tuple[int, ...]],\n                                    data_identifier: str,\n                                    approximate_n_voxels_dataset: float,\n                                    _cache: dict) -> dict:\n        def _features_per_stage(num_stages, max_num_features) -> Tuple[int, ...]:\n            return tuple([min(max_num_features, self.UNet_base_num_features * 2 ** i) for\n                          i in range(num_stages)])\n\n        def _keygen(patch_size, strides):\n            return str(patch_size) + '_' + str(strides)\n\n        assert all([i > 0 for i in spacing]), f\"Spacing must be > 0! 
Spacing: {spacing}\"\n        num_input_channels = len(self.dataset_json['channel_names'].keys()\n                                 if 'channel_names' in self.dataset_json.keys()\n                                 else self.dataset_json['modality'].keys())\n        max_num_features = self.UNet_max_features_2d if len(spacing) == 2 else self.UNet_max_features_3d\n        unet_conv_op = convert_dim_to_conv_op(len(spacing))\n\n        # print(spacing, median_shape, approximate_n_voxels_dataset)\n        # find an initial patch size\n        # we first use the spacing to get an aspect ratio\n        tmp = 1 / np.array(spacing)\n\n        # we then upscale it so that it initially is certainly larger than what we need (rescale to have the same\n        # volume as a patch of size 256 ** 3)\n        # this may need to be adapted when using absurdly large GPU memory targets. Increasing this now would not be\n        # ideal because large initial patch sizes increase computation time because more iterations in the while loop\n        # further down may be required.\n        if len(spacing) == 3:\n            initial_patch_size = [round(i) for i in tmp * (256 ** 3 / np.prod(tmp)) ** (1 / 3)]\n        elif len(spacing) == 2:\n            initial_patch_size = [round(i) for i in tmp * (2048 ** 2 / np.prod(tmp)) ** (1 / 2)]\n        else:\n            raise RuntimeError()\n\n        # clip initial patch size to median_shape. It makes little sense to have it be larger than that. Note that\n        # this is different from how nnU-Net v1 does it!\n        # todo patch size can still get too large because we pad the patch size to a multiple of 2**n\n        initial_patch_size = np.array([min(i, j) for i, j in zip(initial_patch_size, median_shape[:len(spacing)])])\n\n        # use that to get the network topology. 
Note that this changes the patch_size depending on the number of\n        # pooling operations (must be divisible by 2**num_pool in each axis)\n        network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \\\n        shape_must_be_divisible_by = get_pool_and_conv_props(spacing, initial_patch_size,\n                                                             self.UNet_featuremap_min_edge_length,\n                                                             999999)\n        num_stages = len(pool_op_kernel_sizes)\n\n        norm = get_matching_instancenorm(unet_conv_op)\n        architecture_kwargs = {\n            'network_class_name': self.UNet_class.__module__ + '.' + self.UNet_class.__name__,\n            'arch_kwargs': {\n                'n_stages': num_stages,\n                'features_per_stage': _features_per_stage(num_stages, max_num_features),\n                'conv_op': unet_conv_op.__module__ + '.' + unet_conv_op.__name__,\n                'kernel_sizes': conv_kernel_sizes,\n                'strides': pool_op_kernel_sizes,\n                'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages],\n                'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1],\n                'conv_bias': True,\n                'norm_op': norm.__module__ + '.' 
+ norm.__name__,\n                'norm_op_kwargs': {'eps': 1e-5, 'affine': True},\n                'dropout_op': None,\n                'dropout_op_kwargs': None,\n                'nonlin': 'torch.nn.LeakyReLU',\n                'nonlin_kwargs': {'inplace': True},\n            },\n            '_kw_requires_import': ('conv_op', 'norm_op', 'dropout_op', 'nonlin'),\n        }\n\n        # now estimate vram consumption\n        if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys():\n            estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)]\n        else:\n            estimate = self.static_estimate_VRAM_usage(patch_size,\n                                                       num_input_channels,\n                                                       len(self.dataset_json['labels'].keys()),\n                                                       architecture_kwargs['network_class_name'],\n                                                       architecture_kwargs['arch_kwargs'],\n                                                       architecture_kwargs['_kw_requires_import'],\n                                                       )\n            _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate\n\n        # how large is the reference for us here (batch size etc)?\n        # adapt for our vram target\n        reference = (self.UNet_reference_val_2d if len(spacing) == 2 else self.UNet_reference_val_3d) * \\\n                    (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB)\n\n        while estimate > reference:\n            # print(patch_size)\n            # patch size seems to be too large, so we need to reduce it. 
Reduce the axis that currently violates the\n            # aspect ratio the most (that is the largest relative to median shape)\n            axis_to_be_reduced = np.argsort([i / j for i, j in zip(patch_size, median_shape[:len(spacing)])])[-1]\n\n            # we cannot simply reduce that axis by shape_must_be_divisible_by[axis_to_be_reduced] because this\n            # may cause us to skip some valid sizes, for example shape_must_be_divisible_by is 64 for a shape of 256.\n            # If we subtracted that we would end up with 192, skipping 224 which is also a valid patch size\n            # (224 / 2**5 = 7; 7 < 2 * self.UNet_featuremap_min_edge_length(4) so it's valid). So we need to first\n            # subtract shape_must_be_divisible_by, then recompute it and then subtract the\n            # recomputed shape_must_be_divisible_by. Annoying.\n            patch_size = list(patch_size)\n            tmp = deepcopy(patch_size)\n            tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]\n            _, _, _, _, shape_must_be_divisible_by = \\\n                get_pool_and_conv_props(spacing, tmp,\n                                        self.UNet_featuremap_min_edge_length,\n                                        999999)\n            patch_size[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]\n\n            # now recompute topology\n            network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \\\n            shape_must_be_divisible_by = get_pool_and_conv_props(spacing, patch_size,\n                                                                 self.UNet_featuremap_min_edge_length,\n                                                                 999999)\n\n            num_stages = len(pool_op_kernel_sizes)\n            architecture_kwargs['arch_kwargs'].update({\n                'n_stages': num_stages,\n                'kernel_sizes': conv_kernel_sizes,\n                'strides': 
pool_op_kernel_sizes,\n                'features_per_stage': _features_per_stage(num_stages, max_num_features),\n                'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages],\n                'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1],\n            })\n            if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys():\n                estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)]\n            else:\n                estimate = self.static_estimate_VRAM_usage(\n                    patch_size,\n                    num_input_channels,\n                    len(self.dataset_json['labels'].keys()),\n                    architecture_kwargs['network_class_name'],\n                    architecture_kwargs['arch_kwargs'],\n                    architecture_kwargs['_kw_requires_import'],\n                )\n                _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate\n\n        # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was\n        # executed. If not, additional vram headroom is used to increase batch size\n        ref_bs = self.UNet_reference_val_corresp_bs_2d if len(spacing) == 2 else self.UNet_reference_val_corresp_bs_3d\n        batch_size = round((reference / estimate) * ref_bs)\n\n        # we need to cap the batch size to cover at most 5% of the entire dataset. Overfitting precaution. 
We cannot\n        # go smaller than self.UNet_min_batch_size though\n        bs_corresponding_to_5_percent = round(\n            approximate_n_voxels_dataset * self.max_dataset_covered / np.prod(patch_size, dtype=np.float64))\n        batch_size = max(min(batch_size, bs_corresponding_to_5_percent), self.UNet_min_batch_size)\n\n        resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs = self.determine_resampling()\n        resampling_softmax, resampling_softmax_kwargs = self.determine_segmentation_softmax_export_fn()\n\n        normalization_schemes, mask_is_used_for_norm = \\\n            self.determine_normalization_scheme_and_whether_mask_is_used_for_norm()\n\n        plan = {\n            'data_identifier': data_identifier,\n            'preprocessor_name': self.preprocessor_name,\n            'batch_size': batch_size,\n            'patch_size': patch_size,\n            'median_image_size_in_voxels': median_shape,\n            'spacing': spacing,\n            'normalization_schemes': normalization_schemes,\n            'use_mask_for_norm': mask_is_used_for_norm,\n            'resampling_fn_data': resampling_data.__name__,\n            'resampling_fn_seg': resampling_seg.__name__,\n            'resampling_fn_data_kwargs': resampling_data_kwargs,\n            'resampling_fn_seg_kwargs': resampling_seg_kwargs,\n            'resampling_fn_probabilities': resampling_softmax.__name__,\n            'resampling_fn_probabilities_kwargs': resampling_softmax_kwargs,\n            'architecture': architecture_kwargs\n        }\n        return plan\n\n\nclass nnUNetPlannerResEncM(ResEncUNetPlanner):\n    \"\"\"\n    Target is ~9-11 GB VRAM max -> older Titan, RTX 2080ti\n    \"\"\"\n    def __init__(self, dataset_name_or_id: Union[str, int],\n                 gpu_memory_target_in_gb: float = 8,\n                 preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetMPlans',\n                 overwrite_target_spacing: 
Union[List[float], Tuple[float, ...]] = None,\n                 suppress_transpose: bool = False):\n        if gpu_memory_target_in_gb != 8:\n            warnings.warn(\"WARNING: You are running nnUNetPlannerResEncM with a non-standard gpu_memory_target_in_gb. \"\n                          f\"Expected 8, got {gpu_memory_target_in_gb}. \"\n                          \"You should only see this warning if you modified this value intentionally!!\")\n        super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,\n                         overwrite_target_spacing, suppress_transpose)\n        self.UNet_class = ResidualEncoderUNet\n\n        self.UNet_vram_target_GB = gpu_memory_target_in_gb\n        self.UNet_reference_val_corresp_GB = 8\n\n        # this is supposed to give the same GPU memory requirement as the default nnU-Net\n        self.UNet_reference_val_3d = 680000000\n        self.UNet_reference_val_2d = 135000000\n        self.max_dataset_covered = 1\n\n\nclass nnUNetPlannerResEncL(ResEncUNetPlanner):\n    \"\"\"\n    Target is ~24 GB VRAM max -> RTX 4090, Titan RTX, Quadro 6000\n    \"\"\"\n    def __init__(self, dataset_name_or_id: Union[str, int],\n                 gpu_memory_target_in_gb: float = 24,\n                 preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLPlans',\n                 overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,\n                 suppress_transpose: bool = False):\n        if gpu_memory_target_in_gb != 24:\n            warnings.warn(\"WARNING: You are running nnUNetPlannerResEncL with a non-standard gpu_memory_target_in_gb. 
\"\n                          f\"Expected 24, got {gpu_memory_target_in_gb}. \"\n                          \"You should only see this warning if you modified this value intentionally!!\")\n        super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,\n                         overwrite_target_spacing, suppress_transpose)\n        self.UNet_class = ResidualEncoderUNet\n\n        self.UNet_vram_target_GB = gpu_memory_target_in_gb\n        self.UNet_reference_val_corresp_GB = 24\n\n        self.UNet_reference_val_3d = 2100000000  # 1840000000\n        self.UNet_reference_val_2d = 380000000  # 352666667\n        self.max_dataset_covered = 1\n\n\nclass nnUNetPlannerResEncXL(ResEncUNetPlanner):\n    \"\"\"\n    Target is 40 GB VRAM max -> A100 40GB, RTX 6000 Ada Generation\n    \"\"\"\n    def __init__(self, dataset_name_or_id: Union[str, int],\n                 gpu_memory_target_in_gb: float = 40,\n                 preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetXLPlans',\n                 overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,\n                 suppress_transpose: bool = False):\n        if gpu_memory_target_in_gb != 40:\n            warnings.warn(\"WARNING: You are running nnUNetPlannerResEncXL with a non-standard gpu_memory_target_in_gb. 
\"\n                          f\"Expected 40, got {gpu_memory_target_in_gb}. \"\n                          \"You should only see this warning if you modified this value intentionally!!\")\n        super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,\n                         overwrite_target_spacing, suppress_transpose)\n        self.UNet_class = ResidualEncoderUNet\n\n        self.UNet_vram_target_GB = gpu_memory_target_in_gb\n        self.UNet_reference_val_corresp_GB = 40\n\n        self.UNet_reference_val_3d = 3600000000\n        self.UNet_reference_val_2d = 560000000\n        self.max_dataset_covered = 1\n\n\nif __name__ == '__main__':\n    # we know both of these networks run with batch size 2 and 12 on ~8-10GB, respectively\n    net = ResidualEncoderUNet(input_channels=1, n_stages=6, features_per_stage=(32, 64, 128, 256, 320, 320),\n                              conv_op=nn.Conv3d, kernel_sizes=3, strides=(1, 2, 2, 2, 2, 2),\n                              n_blocks_per_stage=(1, 3, 4, 6, 6, 6), num_classes=3,\n                              n_conv_per_stage_decoder=(1, 1, 1, 1, 1),\n                              conv_bias=True, norm_op=nn.InstanceNorm3d, norm_op_kwargs={}, dropout_op=None,\n                              nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True}, deep_supervision=True)\n    print(net.compute_conv_feature_map_size((128, 128, 128)))  # -> 558319104. 
The value you see above was finetuned\n    # from this one to match the regular nnunetplans more closely\n\n    net = ResidualEncoderUNet(input_channels=1, n_stages=7, features_per_stage=(32, 64, 128, 256, 512, 512, 512),\n                              conv_op=nn.Conv2d, kernel_sizes=3, strides=(1, 2, 2, 2, 2, 2, 2),\n                              n_blocks_per_stage=(1, 3, 4, 6, 6, 6, 6), num_classes=3,\n                              n_conv_per_stage_decoder=(1, 1, 1, 1, 1, 1),\n                              conv_bias=True, norm_op=nn.InstanceNorm2d, norm_op_kwargs={}, dropout_op=None,\n                              nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True}, deep_supervision=True)\n    print(net.compute_conv_feature_map_size((512, 512)))  # -> 129793792\n"
  },
  {
    "path": "nnunetv2/experiment_planning/plan_and_preprocess_api.py",
    "content": "import warnings\nfrom typing import List, Type, Optional, Tuple, Union\n\nfrom batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p, load_json\n\nimport nnunetv2\nfrom nnunetv2.configuration import default_num_processes\nfrom nnunetv2.experiment_planning.dataset_fingerprint.fingerprint_extractor import DatasetFingerprintExtractor\nfrom nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner\nfrom nnunetv2.experiment_planning.verify_dataset_integrity import verify_dataset_integrity\nfrom nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed\nfrom nnunetv2.utilities.dataset_name_id_conversion import convert_id_to_dataset_name\nfrom nnunetv2.utilities.find_class_by_name import recursive_find_python_class\nfrom nnunetv2.utilities.plans_handling.plans_handler import PlansManager\nfrom nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets\n\n\ndef extract_fingerprint_dataset(dataset_id: int,\n                                fingerprint_extractor_class: Type[\n                                    DatasetFingerprintExtractor] = DatasetFingerprintExtractor,\n                                num_processes: int = default_num_processes, check_dataset_integrity: bool = False,\n                                clean: bool = True, verbose: bool = True,\n                                show_progress_bar: bool = True):\n    \"\"\"\n    Returns the fingerprint as a dictionary (in addition to saving it)\n    \"\"\"\n    dataset_name = convert_id_to_dataset_name(dataset_id)\n    print(dataset_name)\n\n    if check_dataset_integrity:\n        verify_dataset_integrity(join(nnUNet_raw, dataset_name), num_processes)\n\n    fpe = fingerprint_extractor_class(dataset_id, num_processes, verbose=verbose)\n    if hasattr(fpe, 'show_progress_bar'):\n        fpe.show_progress_bar = show_progress_bar\n    return fpe.run(overwrite_existing=clean)\n\n\ndef extract_fingerprints(dataset_ids: 
List[int], fingerprint_extractor_class_name: str = 'DatasetFingerprintExtractor',\n                         num_processes: int = default_num_processes, check_dataset_integrity: bool = False,\n                         clean: bool = True, verbose: bool = True,\n                         show_progress_bar: bool = True):\n    \"\"\"\n    clean is passed on as overwrite_existing: clean = False reuses an existing fingerprint instead of recomputing it.\n    This is just a switch for use with nnUNetv2_plan_and_preprocess where we don't want to rerun fingerprint\n    extraction every time.\n    \"\"\"\n    fingerprint_extractor_class = recursive_find_python_class(join(nnunetv2.__path__[0], \"experiment_planning\"),\n                                                              fingerprint_extractor_class_name,\n                                                              current_module=\"nnunetv2.experiment_planning\")\n    for d in dataset_ids:\n        extract_fingerprint_dataset(d, fingerprint_extractor_class, num_processes, check_dataset_integrity, clean,\n                                    verbose, show_progress_bar)\n\n\ndef plan_experiment_dataset(dataset_id: int,\n                            experiment_planner_class: Type[ExperimentPlanner] = ExperimentPlanner,\n                            gpu_memory_target_in_gb: Optional[float] = None,\n                            preprocess_class_name: str = 'DefaultPreprocessor',\n                            overwrite_target_spacing: Optional[Tuple[float, ...]] = None,\n                            overwrite_plans_name: Optional[str] = None) -> Tuple[dict, str]:\n    \"\"\"\n    overwrite_target_spacing ONLY applies to 3d_fullres and 3d_cascade_fullres!\n    \"\"\"\n    kwargs = {}\n    if overwrite_plans_name is not None:\n        kwargs['plans_name'] = overwrite_plans_name\n    if gpu_memory_target_in_gb is not None:\n        kwargs['gpu_memory_target_in_gb'] = gpu_memory_target_in_gb\n\n    planner = experiment_planner_class(dataset_id,\n                                       preprocessor_name=preprocess_class_name,\n                       
                overwrite_target_spacing=[float(i) for i in overwrite_target_spacing] if\n                                       overwrite_target_spacing is not None else overwrite_target_spacing,\n                                       suppress_transpose=False,  # might expose this later,\n                                       **kwargs\n                                       )\n    ret = planner.plan_experiment()\n    return ret, planner.plans_identifier\n\n\ndef plan_experiments(dataset_ids: List[int], experiment_planner_class_name: str = 'ExperimentPlanner',\n                     gpu_memory_target_in_gb: Optional[float] = None,\n                     preprocess_class_name: str = 'DefaultPreprocessor',\n                     overwrite_target_spacing: Optional[Tuple[float, ...]] = None,\n                     overwrite_plans_name: Optional[str] = None):\n    \"\"\"\n    overwrite_target_spacing ONLY applies to 3d_fullres and 3d_cascade_fullres!\n    \"\"\"\n    if experiment_planner_class_name == 'ExperimentPlanner':\n        print(\"\\n############################\\n\"\n              \"INFO: You are using the old nnU-Net default planner. We have updated our recommendations. \"\n              \"Please consider using those instead! 
\"\n              \"Read more here: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/resenc_presets.md\"\n              \"\\n############################\\n\")\n    experiment_planner = recursive_find_python_class(join(nnunetv2.__path__[0], \"experiment_planning\"),\n                                                     experiment_planner_class_name,\n                                                     current_module=\"nnunetv2.experiment_planning\")\n    plans_identifier = None\n    for d in dataset_ids:\n        _, plans_identifier = plan_experiment_dataset(d, experiment_planner, gpu_memory_target_in_gb,\n                                                      preprocess_class_name,\n                                                      overwrite_target_spacing, overwrite_plans_name)\n    return plans_identifier\n\n\ndef preprocess_dataset(dataset_id: int,\n                       plans_identifier: str = 'nnUNetPlans',\n                       configurations: Union[Tuple[str], List[str]] = ('2d', '3d_fullres', '3d_lowres'),\n                       num_processes: Union[int, Tuple[int, ...], List[int]] = (8, 4, 8),\n                       verbose: bool = False,\n                       show_progress_bar: bool = True) -> None:\n    if isinstance(num_processes, int):\n        # a single int is not iterable; wrap it so that it is broadcast to all configurations below\n        num_processes = [num_processes]\n    elif not isinstance(num_processes, list):\n        num_processes = list(num_processes)\n    if len(num_processes) == 1:\n        num_processes = num_processes * len(configurations)\n    if len(num_processes) != len(configurations):\n        raise RuntimeError(\n            f'The list provided with num_processes must either have len 1 or as many elements as there are '\n            f'configurations (see --help). 
Number of configurations: {len(configurations)}, length '\n            f'of num_processes: '\n            f'{len(num_processes)}')\n\n    dataset_name = convert_id_to_dataset_name(dataset_id)\n    print(f'Preprocessing dataset {dataset_name}')\n    plans_file = join(nnUNet_preprocessed, dataset_name, plans_identifier + '.json')\n    plans_manager = PlansManager(plans_file)\n    for n, c in zip(num_processes, configurations):\n        print(f'Configuration: {c}...')\n        if c not in plans_manager.available_configurations:\n            print(\n                f\"INFO: Configuration {c} not found in plans file {plans_identifier + '.json'} of \"\n                f\"dataset {dataset_name}. Skipping.\")\n            continue\n        configuration_manager = plans_manager.get_configuration(c)\n        print(configuration_manager)\n        preprocessor = configuration_manager.preprocessor_class(verbose=verbose)\n        if hasattr(preprocessor, 'show_progress_bar'):\n            preprocessor.show_progress_bar = show_progress_bar\n        preprocessor.run(dataset_id, c, plans_identifier, num_processes=n)\n\n    # copy the gt to a folder in the nnUNet_preprocessed so that we can do validation even if the raw data is no\n    # longer there (useful for compute cluster where only the preprocessed data is available)\n    # os/shutil are used here because distutils is no longer part of the standard library as of Python 3.12\n    import os\n    import shutil\n    maybe_mkdir_p(join(nnUNet_preprocessed, dataset_name, 'gt_segmentations'))\n    dataset_json = load_json(join(nnUNet_raw, dataset_name, 'dataset.json'))\n    dataset = get_filenames_of_train_images_and_targets(join(nnUNet_raw, dataset_name), dataset_json)\n    # only copy files that are missing or newer than the ones already present\n    for k in dataset:\n        src = dataset[k]['label']\n        dst = join(nnUNet_preprocessed, dataset_name, 'gt_segmentations', k + dataset_json['file_ending'])\n        if not os.path.isfile(dst) or os.path.getmtime(src) > os.path.getmtime(dst):\n            # copy2 preserves the modification time, so unchanged files are skipped on the next run\n            shutil.copy2(src, dst)\n\n\ndef preprocess(dataset_ids: List[int],\n               plans_identifier: str = 
'nnUNetPlans',\n               configurations: Union[Tuple[str], List[str]] = ('2d', '3d_fullres', '3d_lowres'),\n               num_processes: Union[int, Tuple[int, ...], List[int]] = (8, 4, 8),\n               verbose: bool = False,\n               show_progress_bar: bool = True):\n    for d in dataset_ids:\n        preprocess_dataset(d, plans_identifier, configurations, num_processes, verbose, show_progress_bar)\n"
  },
  {
    "path": "nnunetv2/experiment_planning/plan_and_preprocess_entrypoints.py",
    "content": "from nnunetv2.configuration import default_num_processes\nfrom nnunetv2.experiment_planning.plan_and_preprocess_api import extract_fingerprints, plan_experiments, preprocess\n\n\ndef _add_logging_args(parser):\n    parser.add_argument('--verbose', required=False, action='store_true',\n                        help='Set this to print a lot of stuff. Useful for debugging.')\n    parser.add_argument('--no_pbar', required=False, action='store_true',\n                        help='Set this flag to disable the progress bar. Recommended for cluster/HPC environments.')\n\n\ndef extract_fingerprint_entry():\n    import argparse\n    parser = argparse.ArgumentParser()\n    parser.add_argument('-d', nargs='+', type=int,\n                        help=\"[REQUIRED] List of dataset IDs. Example: 2 4 5. This will run fingerprint extraction, experiment \"\n                             \"planning and preprocessing for these datasets. Can of course also be just one dataset\")\n    parser.add_argument('-fpe', type=str, required=False, default='DatasetFingerprintExtractor',\n                        help='[OPTIONAL] Name of the Dataset Fingerprint Extractor class that should be used. Default is '\n                             '\\'DatasetFingerprintExtractor\\'.')\n    parser.add_argument('-np', type=int, default=default_num_processes, required=False,\n                        help=f'[OPTIONAL] Number of processes used for fingerprint extraction. '\n                             f'Default: {default_num_processes}')\n    parser.add_argument(\"--verify_dataset_integrity\", required=False, default=False, action=\"store_true\",\n                        help=\"[RECOMMENDED] set this flag to check the dataset integrity. 
This is useful and should be done once for \"\n                             \"each dataset!\")\n    parser.add_argument(\"--clean\", required=False, default=False, action=\"store_true\",\n                        help='[OPTIONAL] Set this flag to overwrite existing fingerprints. If this flag is not set and a '\n                             'fingerprint already exists, the fingerprint extractor will not run.')\n    _add_logging_args(parser)\n    args, unrecognized_args = parser.parse_known_args()\n    extract_fingerprints(args.d, args.fpe, args.np, args.verify_dataset_integrity, args.clean, args.verbose,\n                         show_progress_bar=not args.no_pbar)\n\n\ndef plan_experiment_entry():\n    import argparse\n    parser = argparse.ArgumentParser()\n    parser.add_argument('-d', nargs='+', type=int,\n                        help=\"[REQUIRED] List of dataset IDs. Example: 2 4 5. This will run experiment \"\n                             \"planning for these datasets. Can of course also be just one dataset\")\n    parser.add_argument('-pl', type=str, default='ExperimentPlanner', required=False,\n                        help='[OPTIONAL] Name of the Experiment Planner class that should be used. Default is '\n                             '\\'ExperimentPlanner\\'. Note: There is no longer a distinction between 2d and 3d planner. '\n                             'It\\'s an all in one solution now. Wuch. Such amazing.')\n    parser.add_argument('-gpu_memory_target', default=None, type=float, required=False,\n                        help='[OPTIONAL] DANGER ZONE! Sets a custom GPU memory target (in GB). Default: None (=Planner '\n                             'class default is used). Changing this will '\n                             'affect patch and batch size and will '\n                             'definitely affect your model performance! 
Only use this if you really know what you '\n                             'are doing and NEVER use this without running the default nnU-Net first as a baseline.')\n    parser.add_argument('-preprocessor_name', default='DefaultPreprocessor', type=str, required=False,\n                        help='[OPTIONAL] DANGER ZONE! Sets a custom preprocessor class. This class must be located in '\n                             'nnunetv2.preprocessing. Default: \\'DefaultPreprocessor\\'. Changing this may affect your '\n                             'model performance! Only use this if you really know what you '\n                             'are doing and NEVER use this without running the default nnU-Net first (as a baseline).')\n    parser.add_argument('-overwrite_target_spacing', default=None, nargs='+', required=False,\n                        help='[OPTIONAL] DANGER ZONE! Sets a custom target spacing for the 3d_fullres and 3d_cascade_fullres '\n                             'configurations. Default: None [no changes]. Changing this will affect image size and '\n                             'potentially patch and batch '\n                             'size. This will definitely affect your model performance! Only use this if you really '\n                             'know what you are doing and NEVER use this without running the default nnU-Net first '\n                             '(as a baseline). Changing the target spacing for the other configurations is currently '\n                             'not implemented. New target spacing must be a list of three numbers!')\n    parser.add_argument('-overwrite_plans_name', default=None, required=False,\n                        help='[OPTIONAL] DANGER ZONE! 
If you used -gpu_memory_target, -preprocessor_name or '\n                             '-overwrite_target_spacing it is best practice to use -overwrite_plans_name to generate a '\n                             'differently named plans file such that the nnunet default plans are not '\n                             'overwritten. You will then need to specify your custom plans file with -p whenever '\n                             'running other nnunet commands (training, inference etc)')\n    args, unrecognized_args = parser.parse_known_args()\n    plan_experiments(args.d, args.pl, args.gpu_memory_target, args.preprocessor_name, args.overwrite_target_spacing,\n                     args.overwrite_plans_name)\n\n\ndef preprocess_entry():\n    import argparse\n    parser = argparse.ArgumentParser()\n    parser.add_argument('-d', nargs='+', type=int,\n                        help=\"[REQUIRED] List of dataset IDs. Example: 2 4 5. This will run \"\n                             \"preprocessing for these datasets. Can of course also be just one dataset\")\n    parser.add_argument('-plans_name', default='nnUNetPlans', required=False,\n                        help='[OPTIONAL] You can use this to specify a custom plans file that you may have generated')\n    parser.add_argument('-c', required=False, default=['2d', '3d_fullres', '3d_lowres'], nargs='+',\n                        help='[OPTIONAL] Configurations for which the preprocessing should be run. Default: 2d 3d_fullres '\n                             '3d_lowres. 3d_cascade_fullres does not need to be specified because it uses the data '\n                             'from 3d_fullres. Configurations that do not exist for some dataset will be skipped.')\n    parser.add_argument('-np', type=int, nargs='+', default=None, required=False,\n                        help=\"[OPTIONAL] Use this to define how many processes are to be used. 
If this is just one number then \"\n                             \"this number of processes is used for all configurations specified with -c. If it's a \"\n                             \"list of numbers this list must have as many elements as there are configurations. We \"\n                             \"then iterate over zip(configs, num_processes) to determine the number of processes \"\n                             \"used for each configuration. More processes is always faster (up to the number of \"\n                             \"threads your PC can support, so 8 for a 4 core CPU with hyperthreading. If you don't \"\n                             \"know what that is then don't touch it, or at least don't increase it!). DANGER: More \"\n                             \"often than not the number of processes that can be used is limited by the amount of \"\n                             \"RAM available. Image resampling takes up a lot of RAM. MONITOR RAM USAGE AND \"\n                             \"DECREASE -np IF YOUR RAM FILLS UP TOO MUCH! Default: 8 processes for 2d, 4 \"\n                             \"for 3d_fullres, 8 for 3d_lowres and 4 for everything else\")\n    _add_logging_args(parser)\n    args, unrecognized_args = parser.parse_known_args()\n    if args.np is None:\n        default_np = {\"2d\": 8, \"3d_fullres\": 4, \"3d_lowres\": 8}\n        np = [default_np[c] if c in default_np.keys() else 4 for c in args.c]\n    else:\n        np = args.np\n    preprocess(args.d, args.plans_name, configurations=args.c, num_processes=np, verbose=args.verbose,\n               show_progress_bar=not args.no_pbar)\n\n\ndef plan_and_preprocess_entry():\n    import argparse\n    parser = argparse.ArgumentParser()\n    parser.add_argument('-d', nargs='+', type=int,\n                        help=\"[REQUIRED] List of dataset IDs. Example: 2 4 5. This will run fingerprint extraction, experiment \"\n                             \"planning and preprocessing for these datasets. 
Can of course also be just one dataset\")\n    parser.add_argument('-fpe', type=str, required=False, default='DatasetFingerprintExtractor',\n                        help='[OPTIONAL] Name of the Dataset Fingerprint Extractor class that should be used. Default is '\n                             '\\'DatasetFingerprintExtractor\\'.')\n    parser.add_argument('-npfp', type=int, default=8, required=False,\n                        help='[OPTIONAL] Number of processes used for fingerprint extraction. Default: 8')\n    parser.add_argument(\"--verify_dataset_integrity\", required=False, default=False, action=\"store_true\",\n                        help=\"[RECOMMENDED] set this flag to check the dataset integrity. This is useful and should be done once for \"\n                             \"each dataset!\")\n    parser.add_argument('--no_pp', default=False, action='store_true', required=False,\n                        help='[OPTIONAL] Set this to only run fingerprint extraction and experiment planning (no '\n                             'preprocessing). Useful for debugging.')\n    parser.add_argument(\"--clean\", required=False, default=False, action=\"store_true\",\n                        help='[OPTIONAL] Set this flag to overwrite existing fingerprints. If this flag is not set and a '\n                             'fingerprint already exists, the fingerprint extractor will not run. REQUIRED IF YOU '\n                             'CHANGE THE DATASET FINGERPRINT EXTRACTOR OR MAKE CHANGES TO THE DATASET!')\n    parser.add_argument('-pl', type=str, default='ExperimentPlanner', required=False,\n                        help='[OPTIONAL] Name of the Experiment Planner class that should be used. Default is '\n                             '\\'ExperimentPlanner\\'. Note: There is no longer a distinction between 2d and 3d planner. '\n                             'It\\'s an all in one solution now. Wuch. 
Such amazing.')\n    parser.add_argument('-gpu_memory_target', default=None, type=float, required=False,\n                        help='[OPTIONAL] DANGER ZONE! Sets a custom GPU memory target (in GB). Default: None (=Planner '\n                             'class default is used). Changing this will '\n                             'affect patch and batch size and will '\n                             'definitely affect your model performance! Only use this if you really know what you '\n                             'are doing and NEVER use this without running the default nnU-Net first as a baseline.')\n    parser.add_argument('-preprocessor_name', default='DefaultPreprocessor', type=str, required=False,\n                        help='[OPTIONAL] DANGER ZONE! Sets a custom preprocessor class. This class must be located in '\n                             'nnunetv2.preprocessing. Default: \\'DefaultPreprocessor\\'. Changing this may affect your '\n                             'model performance! Only use this if you really know what you '\n                             'are doing and NEVER use this without running the default nnU-Net first (as a baseline).')\n    parser.add_argument('-overwrite_target_spacing', default=None, nargs='+', required=False,\n                        help='[OPTIONAL] DANGER ZONE! Sets a custom target spacing for the 3d_fullres and 3d_cascade_fullres '\n                             'configurations. Default: None [no changes]. Changing this will affect image size and '\n                             'potentially patch and batch '\n                             'size. This will definitely affect your model performance! Only use this if you really '\n                             'know what you are doing and NEVER use this without running the default nnU-Net first '\n                             '(as a baseline). Changing the target spacing for the other configurations is currently '\n                             'not implemented. 
New target spacing must be a list of three numbers!')\n    parser.add_argument('-overwrite_plans_name', default=None, required=False,\n                        help='[OPTIONAL] USE A CUSTOM PLANS IDENTIFIER. If you used -gpu_memory_target, '\n                             '-preprocessor_name or '\n                             '-overwrite_target_spacing it is best practice to use -overwrite_plans_name to generate a '\n                             'differently named plans file such that the nnunet default plans are not '\n                             'overwritten. You will then need to specify your custom plans file with -p whenever '\n                             'running other nnunet commands (training, inference etc)')\n    parser.add_argument('-c', required=False, default=['2d', '3d_fullres', '3d_lowres'], nargs='+',\n                        help='[OPTIONAL] Configurations for which the preprocessing should be run. Default: 2d 3d_fullres '\n                             '3d_lowres. 3d_cascade_fullres does not need to be specified because it uses the data '\n                             'from 3d_fullres. Configurations that do not exist for some dataset will be skipped.')\n    parser.add_argument('-np', type=int, nargs='+', default=None, required=False,\n                        help=\"[OPTIONAL] Use this to define how many processes are to be used. If this is just one number then \"\n                             \"this number of processes is used for all configurations specified with -c. If it's a \"\n                             \"list of numbers this list must have as many elements as there are configurations. We \"\n                             \"then iterate over zip(configs, num_processes) to determine the number of processes \"\n                             \"used for each configuration. More processes is always faster (up to the number of \"\n                             \"threads your PC can support, so 8 for a 4 core CPU with hyperthreading. 
If you don't \"\n                             \"know what that is then don't touch it, or at least don't increase it!). DANGER: More \"\n                             \"often than not the number of processes that can be used is limited by the amount of \"\n                             \"RAM available. Image resampling takes up a lot of RAM. MONITOR RAM USAGE AND \"\n                             \"DECREASE -np IF YOUR RAM FILLS UP TOO MUCH! Default: 8 processes for 2d, 4 \"\n                             \"for 3d_fullres, 8 for 3d_lowres and 4 for everything else\")\n    _add_logging_args(parser)\n    args = parser.parse_args()\n\n    # fingerprint extraction\n    print(\"Fingerprint extraction...\")\n    extract_fingerprints(args.d, args.fpe, args.npfp, args.verify_dataset_integrity, args.clean, args.verbose,\n                         show_progress_bar=not args.no_pbar)\n\n    # experiment planning\n    print('Experiment planning...')\n    plans_identifier = plan_experiments(args.d, args.pl, args.gpu_memory_target, args.preprocessor_name,\n                                        args.overwrite_target_spacing, args.overwrite_plans_name)\n\n    # manage default np\n    if args.np is None:\n        default_np = {\"2d\": 8, \"3d_fullres\": 4, \"3d_lowres\": 8}\n        np = [default_np[c] if c in default_np.keys() else 4 for c in args.c]\n    else:\n        np = args.np\n    # preprocessing\n    if not args.no_pp:\n        print('Preprocessing...')\n        preprocess(args.d, plans_identifier, args.c, np, args.verbose, show_progress_bar=not args.no_pbar)\n\n\nif __name__ == '__main__':\n    plan_and_preprocess_entry()\n"
  },
  {
    "path": "nnunetv2/experiment_planning/plans_for_pretraining/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/experiment_planning/plans_for_pretraining/move_plans_between_datasets.py",
    "content": "import argparse\nfrom typing import Union\n\nfrom batchgenerators.utilities.file_and_folder_operations import join, isdir, isfile, load_json, save_json\n\nfrom nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json\nfrom nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw\nfrom nnunetv2.utilities.file_path_utilities import maybe_convert_to_dataset_name\nfrom nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets\n\n\ndef move_plans_between_datasets(\n        source_dataset_name_or_id: Union[int, str],\n        target_dataset_name_or_id: Union[int, str],\n        source_plans_identifier: str,\n        target_plans_identifier: str = None):\n    source_dataset_name = maybe_convert_to_dataset_name(source_dataset_name_or_id)\n    target_dataset_name = maybe_convert_to_dataset_name(target_dataset_name_or_id)\n\n    if target_plans_identifier is None:\n        target_plans_identifier = source_plans_identifier\n\n    source_folder = join(nnUNet_preprocessed, source_dataset_name)\n    assert isdir(source_folder), f\"Cannot move plans because preprocessed directory of source dataset is missing. \" \\\n                                 f\"Run nnUNetv2_plan_and_preprocess for source dataset first!\"\n\n    source_plans_file = join(source_folder, source_plans_identifier + '.json')\n    assert isfile(source_plans_file), f\"Source plans are missing. Run the corresponding experiment planning first! 
\" \\\n                                      f\"Expected file: {source_plans_file}\"\n\n    source_plans = load_json(source_plans_file)\n    source_plans['dataset_name'] = target_dataset_name\n\n    # we need to change data_identifier to use target_plans_identifier\n    if target_plans_identifier != source_plans_identifier:\n        for c in source_plans['configurations'].keys():\n            if 'data_identifier' in source_plans['configurations'][c].keys():\n                old_identifier = source_plans['configurations'][c][\"data_identifier\"]\n                if old_identifier.startswith(source_plans_identifier):\n                    new_identifier = target_plans_identifier + old_identifier[len(source_plans_identifier):]\n                else:\n                    new_identifier = target_plans_identifier + '_' + old_identifier\n                source_plans['configurations'][c][\"data_identifier\"] = new_identifier\n\n    # we need to change the reader writer class!\n    target_raw_data_dir = join(nnUNet_raw, target_dataset_name)\n    target_dataset_json = load_json(join(target_raw_data_dir, 'dataset.json'))\n\n    # we may need to change the reader/writer\n    # pick any file from the source dataset\n    dataset = get_filenames_of_train_images_and_targets(target_raw_data_dir, target_dataset_json)\n    example_image = dataset[dataset.keys().__iter__().__next__()]['images'][0]\n    rw = determine_reader_writer_from_dataset_json(target_dataset_json, example_image, allow_nonmatching_filename=True,\n                                                   verbose=False)\n\n    source_plans[\"image_reader_writer\"] = rw.__name__\n    if target_plans_identifier is not None:\n        source_plans[\"plans_name\"] = target_plans_identifier\n\n    save_json(source_plans, join(nnUNet_preprocessed, target_dataset_name, target_plans_identifier + '.json'),\n              sort_keys=False)\n\n\ndef entry_point_move_plans_between_datasets():\n    parser = argparse.ArgumentParser()\n    
parser.add_argument('-s', type=str, required=True,\n                        help='Source dataset name or id')\n    parser.add_argument('-t', type=str, required=True,\n                        help='Target dataset name or id')\n    parser.add_argument('-sp', type=str, required=True,\n                        help='Source plans identifier. If your plans are named \"nnUNetPlans.json\" then the '\n                             'identifier would be nnUNetPlans')\n    parser.add_argument('-tp', type=str, required=False, default=None,\n                        help='Target plans identifier. Default is None meaning the source plans identifier will '\n                             'be kept. Not recommended if the source plans identifier is a default nnU-Net identifier '\n                             'such as nnUNetPlans!!!')\n    args = parser.parse_args()\n    move_plans_between_datasets(args.s, args.t, args.sp, args.tp)\n\n\nif __name__ == '__main__':\n    move_plans_between_datasets(2, 4, 'nnUNetPlans', 'nnUNetPlansFrom2')\n"
  },
  {
    "path": "nnunetv2/experiment_planning/verify_dataset_integrity.py",
    "content": "#    Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center\n#    (DKFZ), Heidelberg, Germany\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not use this file except in compliance with the License.\n#    You may obtain a copy of the License at\n#\n#        http://www.apache.org/licenses/LICENSE-2.0\n#\n#    Unless required by applicable law or agreed to in writing, software\n#    distributed under the License is distributed on an \"AS IS\" BASIS,\n#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#    See the License for the specific language governing permissions and\n#    limitations under the License.\nimport multiprocessing\nfrom typing import Type\n\nimport numpy as np\nimport pandas as pd\nfrom batchgenerators.utilities.file_and_folder_operations import *\n\nfrom nnunetv2.imageio.base_reader_writer import BaseReaderWriter\nfrom nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json\nfrom nnunetv2.paths import nnUNet_raw\nfrom nnunetv2.utilities.label_handling.label_handling import LabelManager\nfrom nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets\n\n\ndef verify_labels(label_file: str, readerclass: Type[BaseReaderWriter], expected_labels: List[int]) -> bool:\n    rw = readerclass()\n    seg, properties = rw.read_seg(label_file)\n    found_labels = np.sort(pd.unique(seg.ravel()))  # np.unique(seg)\n    unexpected_labels = [i for i in found_labels if i not in expected_labels]\n    if len(found_labels) == 0 and found_labels[0] == 0:\n        print('WARNING: File %s only has label 0 (which should be background). This may be intentional or not, '\n              'up to you.' 
% label_file)\n    if len(unexpected_labels) > 0:\n        print(\"Error: Unexpected labels found in file %s.\\nExpected: %s\\nFound: %s\" % (label_file, expected_labels,\n                                                                                       found_labels))\n        return False\n    return True\n\n\ndef check_cases(image_files: List[str], label_file: str, expected_num_channels: int,\n                readerclass: Type[BaseReaderWriter]) -> bool:\n    rw = readerclass()\n    ret = True\n\n    images, properties_image = rw.read_images(image_files)\n    segmentation, properties_seg = rw.read_seg(label_file)\n\n    # check for nans\n    if np.any(np.isnan(images)):\n        print(f'Images contain NaN pixel values. You need to fix that by '\n              f'replacing NaN values with something that makes sense for your images!\\nImages:\\n{image_files}')\n        ret = False\n    if np.any(np.isnan(segmentation)):\n        print(f'Segmentation contains NaN pixel values. You need to fix that.\\nSegmentation:\\n{label_file}')\n        ret = False\n\n    # check shapes\n    shape_image = images.shape[1:]\n    shape_seg = segmentation.shape[1:]\n    if shape_image != shape_seg:\n        print('Error: Shape mismatch between segmentation and corresponding images. \\nShape images: %s. '\n              '\\nShape seg: %s. \\nImage files: %s. \\nSeg file: %s\\n' %\n              (shape_image, shape_seg, image_files, label_file))\n        ret = False\n\n    # check spacings\n    spacing_images = properties_image['spacing']\n    spacing_seg = properties_seg['spacing']\n    if not np.allclose(spacing_seg, spacing_images):\n        print('Error: Spacing mismatch between segmentation and corresponding images. \\nSpacing images: %s. '\n              '\\nSpacing seg: %s. \\nImage files: %s. 
\\nSeg file: %s\\n' %\n              (spacing_images, spacing_seg, image_files, label_file))\n        ret = False\n\n    # check modalities\n    if not len(images) == expected_num_channels:\n        print('Error: Unexpected number of modalities. \\nExpected: %d. \\nGot: %d. \\nImages: %s\\n'\n              % (expected_num_channels, len(images), image_files))\n        ret = False\n\n    # nibabel checks\n    if 'nibabel_stuff' in properties_image.keys():\n        # this image was read with NibabelIO\n        affine_image = properties_image['nibabel_stuff']['original_affine']\n        affine_seg = properties_seg['nibabel_stuff']['original_affine']\n        if not np.allclose(affine_image, affine_seg):\n            print('WARNING: Affine is not the same for image and seg! \\nAffine image: %s \\nAffine seg: %s\\n'\n                  'Image files: %s. \\nSeg file: %s.\\nThis can be a problem but doesn\\'t have to be. Please run '\n                  'nnUNetv2_plot_overlay_pngs to verify if everything is OK!\\n'\n                  % (affine_image, affine_seg, image_files, label_file))\n\n    # sitk checks\n    if 'sitk_stuff' in properties_image.keys():\n        # this image was read with SimpleITKIO\n        # spacing has already been checked, only check direction and origin\n        origin_image = properties_image['sitk_stuff']['origin']\n        origin_seg = properties_seg['sitk_stuff']['origin']\n        if not np.allclose(origin_image, origin_seg):\n            print('Warning: Origin mismatch between segmentation and corresponding images. \\nOrigin images: %s. '\n                  '\\nOrigin seg: %s. \\nImage files: %s. 
\\nSeg file: %s\\n' %\n                  (origin_image, origin_seg, image_files, label_file))\n        direction_image = properties_image['sitk_stuff']['direction']\n        direction_seg = properties_seg['sitk_stuff']['direction']\n        if not np.allclose(direction_image, direction_seg):\n            print('Warning: Direction mismatch between segmentation and corresponding images. \\nDirection images: %s. '\n                  '\\nDirection seg: %s. \\nImage files: %s. \\nSeg file: %s\\n' %\n                  (direction_image, direction_seg, image_files, label_file))\n\n    return ret\n\n\ndef verify_dataset_integrity(folder: str, num_processes: int = 8) -> None:\n    \"\"\"\n    folder needs the imagesTr, imagesTs and labelsTr subfolders. There also needs to be a dataset.json\n    checks if the expected number of training cases and labels are present\n    for each case, if possible, checks whether the pixel grids are aligned\n    checks whether the labels really only contain values they should\n    :param folder:\n    :return:\n    \"\"\"\n    assert isfile(join(folder, \"dataset.json\")), f\"There needs to be a dataset.json file in folder, folder={folder}\"\n    dataset_json = load_json(join(folder, \"dataset.json\"))\n\n    if not 'dataset' in dataset_json.keys():\n        assert isdir(join(folder, \"imagesTr\")), f\"There needs to be a imagesTr subfolder in folder, folder={folder}\"\n        assert isdir(join(folder, \"labelsTr\")), f\"There needs to be a labelsTr subfolder in folder, folder={folder}\"\n\n    # make sure all required keys are there\n    dataset_keys = list(dataset_json.keys())\n    required_keys = ['labels', \"channel_names\", \"numTraining\", \"file_ending\"]\n    assert all([i in dataset_keys for i in required_keys]), 'not all required keys are present in dataset.json.' 
\\\n                                                            '\\n\\nRequired: \\n%s\\n\\nPresent: \\n%s\\n\\nMissing: ' \\\n                                                            '\\n%s\\n\\nUnused by nnU-Net:\\n%s' % \\\n                                                            (str(required_keys),\n                                                             str(dataset_keys),\n                                                             str([i for i in required_keys if i not in dataset_keys]),\n                                                             str([i for i in dataset_keys if i not in required_keys]))\n\n    expected_num_training = dataset_json['numTraining']\n    num_modalities = len(dataset_json['channel_names'].keys()\n                         if 'channel_names' in dataset_json.keys()\n                         else dataset_json['modality'].keys())\n    file_ending = dataset_json['file_ending']\n\n    dataset = get_filenames_of_train_images_and_targets(folder, dataset_json)\n\n    # check if the right number of training cases is present\n    assert len(dataset) == expected_num_training, 'Did not find the expected number of training cases ' \\\n                                                               '(%d). 
Found %d instead.\\nExamples: %s' % \\\n                                                               (expected_num_training, len(dataset),\n                                                                list(dataset.keys())[:5])\n\n    # check if corresponding labels are present\n    if 'dataset' in dataset_json.keys():\n        # just check if everything is there\n        ok = True\n        missing_images = []\n        missing_labels = []\n        for k in dataset:\n            for i in dataset[k]['images']:\n                if not isfile(i):\n                    missing_images.append(i)\n                    ok = False\n            if not isfile(dataset[k]['label']):\n                missing_labels.append(dataset[k]['label'])\n                ok = False\n        if not ok:\n            raise FileNotFoundError(f\"Some expected files were missing. Make sure you are properly referencing them \"\n                                    f\"in the dataset.json. Or use imagesTr & labelsTr folders!\\nMissing images:\"\n                                    f\"\\n{missing_images}\\n\\nMissing labels:\\n{missing_labels}\")\n    else:\n        # old code that uses imagestr and labelstr folders\n        labelfiles = subfiles(join(folder, 'labelsTr'), suffix=file_ending, join=False)\n        label_identifiers = [i[:-len(file_ending)] for i in labelfiles]\n        labels_present = [i in label_identifiers for i in dataset.keys()]\n        missing = [i for j, i in enumerate(dataset.keys()) if not labels_present[j]]\n        assert all(labels_present), f'not all training cases have a label file in labelsTr. Fix that. Missing: {missing}'\n\n    labelfiles = [v['label'] for v in dataset.values()]\n    image_files = [v['images'] for v in dataset.values()]\n\n    # no plans exist yet, so we can't use PlansManager and gotta roll with the default. 
It's unlikely to cause\n    # problems anyway\n    label_manager = LabelManager(dataset_json['labels'], regions_class_order=dataset_json.get('regions_class_order'))\n    expected_labels = label_manager.all_labels\n    if label_manager.has_ignore_label:\n        expected_labels.append(label_manager.ignore_label)\n    labels_valid_consecutive = np.ediff1d(expected_labels) == 1\n    assert all(\n        labels_valid_consecutive), f'Labels must be in consecutive order (0, 1, 2, ...). The labels {np.array(expected_labels)[1:][~labels_valid_consecutive]} do not satisfy this restriction'\n\n    # determine reader/writer class\n    reader_writer_class = determine_reader_writer_from_dataset_json(dataset_json, dataset[dataset.keys().__iter__().__next__()]['images'][0])\n\n    # check whether only the desired labels are present\n    with multiprocessing.get_context(\"spawn\").Pool(num_processes) as p:\n        result = p.starmap(\n            verify_labels,\n            zip(labelfiles, [reader_writer_class] * len(labelfiles), [expected_labels] * len(labelfiles))\n        )\n        if not all(result):\n            raise RuntimeError(\n                'Some segmentation images contained unexpected labels. Please check text output above to see which one(s).')\n\n        # check whether shapes and spacings match between images and labels\n        result = p.starmap(\n            check_cases,\n            zip(image_files, labelfiles, [num_modalities] * expected_num_training,\n                [reader_writer_class] * expected_num_training)\n        )\n        if not all(result):\n            raise RuntimeError(\n                'Some images have errors. Please check text output above to see which one(s) and what\\'s going on.')\n\n    # check for nans\n    # check all same orientation nibabel\n    print('\\n####################')\n    print('verify_dataset_integrity Done. 
\\nIf you didn\\'t see any error messages then your dataset is most likely OK!')\n    print('####################\\n')\n\n\nif __name__ == \"__main__\":\n    # investigate geometry issues\n    example_folder = join(nnUNet_raw, 'Dataset250_COMPUTING_it0')\n    num_processes = 6\n    verify_dataset_integrity(example_folder, num_processes)\n"
  },
  {
    "path": "nnunetv2/imageio/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/imageio/base_reader_writer.py",
    "content": "#    Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center\n#    (DKFZ), Heidelberg, Germany\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not use this file except in compliance with the License.\n#    You may obtain a copy of the License at\n#\n#        http://www.apache.org/licenses/LICENSE-2.0\n#\n#    Unless required by applicable law or agreed to in writing, software\n#    distributed under the License is distributed on an \"AS IS\" BASIS,\n#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#    See the License for the specific language governing permissions and\n#    limitations under the License.\n\nfrom abc import ABC, abstractmethod\nfrom typing import Tuple, Union, List\nimport numpy as np\n\n\nclass BaseReaderWriter(ABC):\n    @staticmethod\n    def _check_all_same(input_list):\n        if len(input_list) == 1:\n            return True\n        else:\n            # compare all entries to the first\n            return np.allclose(input_list[0], input_list[1:])\n\n    @staticmethod\n    def _check_all_same_array(input_list):\n        # compare all entries to the first\n        for i in input_list[1:]:\n            if i.shape != input_list[0].shape or not np.allclose(i, input_list[0]):\n                return False\n        return True\n\n    @abstractmethod\n    def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:\n        \"\"\"\n        Reads a sequence of images and returns a 4d (!) np.ndarray along with a dictionary. 
The 4d array must have the\n        modalities (or color channels, or however you would like to call them) in its first axis, followed by the\n        spatial dimensions (so shape must be c,x,y,z where c is the number of modalities (can be 1)).\n        Use the dictionary to store necessary meta information that is lost when converting to numpy arrays, for\n        example the Spacing, Orientation and Direction of the image. This dictionary will be handed over to write_seg\n        for exporting the predicted segmentations, so make sure you have everything you need in there!\n\n        IMPORTANT: dict MUST have a 'spacing' key with a tuple/list of length 3 with the voxel spacing of the np.ndarray.\n        Example: my_dict = {'spacing': (3, 0.5, 0.5), ...}. This is needed for planning and\n        preprocessing. The ordering of the numbers must correspond to the axis ordering in the returned numpy array. So\n        if the array has shape c,x,y,z and the spacing is (a,b,c) then a must be the spacing of x, b the spacing of y\n        and c the spacing of z.\n\n        In the case of 2D images, the returned array should have shape (c, 1, x, y) and the spacing should be\n        (999, sp_x, sp_y). Make sure 999 is larger than sp_x and sp_y! Example: shape=(3, 1, 224, 224),\n        spacing=(999, 1, 1)\n\n        For images that don't have a spacing, set the spacing to 1 (2d exception with 999 for the first axis still applies!)\n\n        :param image_fnames:\n        :return:\n            1) a np.ndarray of shape (c, x, y, z) where c is the number of image channels (can be 1) and x, y, z are\n            the spatial dimensions (set x=1 for 2D! Example: (3, 1, 224, 224) for RGB image).\n            2) a dictionary with metadata. This can be anything. BUT it HAS to include a {'spacing': (a, b, c)} where a\n            is the spacing of x, b of y and c of z! If an image doesn't have spacing, just set this to 1. For 2D, set\n            a=999 (largest spacing value! 
Make it larger than b and c)\n\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]:\n        \"\"\"\n        Same requirements as BaseReaderWriter.read_images. Returned segmentations must have shape 1,x,y,z. Multiple\n        segmentations are not (yet?) allowed.\n\n        If images and segmentations can be read the same way you can just `return self.read_images((seg_fname,))`\n        :param seg_fname:\n        :return:\n            1) a np.ndarray of shape (1, x, y, z) where x, y, z are\n            the spatial dimensions (set x=1 for 2D! Example: (1, 1, 224, 224) for 2D segmentation).\n            2) a dictionary with metadata. This can be anything. BUT it HAS to include a {'spacing': (a, b, c)} where a\n            is the spacing of x, b of y and c of z! If an image doesn't have spacing, just set this to 1. For 2D, set\n            a=999 (largest spacing value! Make it larger than b and c)\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None:\n        \"\"\"\n        Export the predicted segmentation to the desired file format. The given seg array will have the same shape and\n        orientation as the corresponding image data, so you don't need to do any resampling or whatever. Just save :-)\n\n        properties is the same dictionary you created during read_images/read_seg so you can use the information here\n        to restore metadata\n\n        IMPORTANT: Segmentations are always 3D! If your input images were 2d then the segmentation will have shape\n        1,x,y. You need to catch that and export accordingly (for 2d images you need to convert the 3d segmentation\n        to 2d via seg = seg[0])!\n\n        :param seg: A segmentation (np.ndarray, integer) of shape (x, y, z). 
For 2D segmentations this will be (1, y, z)!\n        :param output_fname:\n        :param properties: the dictionary that you created in read_images (the ones this segmentation is based on).\n        Use this to restore metadata\n        :return:\n        \"\"\"\n        pass"
  },
  {
    "path": "nnunetv2/imageio/natural_image_reader_writer.py",
    "content": "#    Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center\n#    (DKFZ), Heidelberg, Germany\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not use this file except in compliance with the License.\n#    You may obtain a copy of the License at\n#\n#        http://www.apache.org/licenses/LICENSE-2.0\n#\n#    Unless required by applicable law or agreed to in writing, software\n#    distributed under the License is distributed on an \"AS IS\" BASIS,\n#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#    See the License for the specific language governing permissions and\n#    limitations under the License.\n\nfrom typing import Tuple, Union, List\nimport numpy as np\nfrom nnunetv2.imageio.base_reader_writer import BaseReaderWriter\nfrom skimage import io\n\n\nclass NaturalImage2DIO(BaseReaderWriter):\n    \"\"\"\n    ONLY SUPPORTS 2D IMAGES!!!\n    \"\"\"\n\n    # there are surely more we could add here. Everything that can be read by skimage.io should be supported\n    supported_file_endings = [\n        '.png',\n        # '.jpg',\n        # '.jpeg', # jpg not supported because we cannot allow lossy compression! 
segmentation maps!\n        '.bmp',\n        '.tif'\n    ]\n\n    def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:\n        images = []\n        for f in image_fnames:\n            npy_img = io.imread(f)\n            if npy_img.ndim == 3:\n                # rgb image, last dimension should be the color channel and the size of that channel should be 3\n                # (or 4 if we have alpha)\n                assert npy_img.shape[-1] == 3 or npy_img.shape[-1] == 4, \"If image has three dimensions then the last \" \\\n                                                                         \"dimension must have shape 3 or 4 \" \\\n                                                                         f\"(RGB or RGBA). Image shape here is {npy_img.shape}\"\n                # move RGB(A) to front, add additional dim so that we have shape (c, 1, X, Y), where c is either 3 or 4\n                images.append(npy_img.transpose((2, 0, 1))[:, None])\n            elif npy_img.ndim == 2:\n                # grayscale image\n                images.append(npy_img[None, None])\n\n        if not self._check_all_same([i.shape for i in images]):\n            print('ERROR! 
Not all input images have the same shape!')\n            print('Shapes:')\n            print([i.shape for i in images])\n            print('Image files:')\n            print(image_fnames)\n            raise RuntimeError()\n        return np.vstack(images, dtype=np.float32, casting='unsafe'), {'spacing': (999, 1, 1)}\n\n    def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]:\n        return self.read_images((seg_fname, ))\n\n    def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None:\n        io.imsave(output_fname, seg[0].astype(np.uint8 if np.max(seg) < 255 else np.uint16, copy=False), check_contrast=False)\n\n\nif __name__ == '__main__':\n    images = ('/media/fabian/data/nnUNet_raw/Dataset120_RoadSegmentation/imagesTr/img-11_0000.png',)\n    segmentation = '/media/fabian/data/nnUNet_raw/Dataset120_RoadSegmentation/labelsTr/img-11.png'\n    imgio = NaturalImage2DIO()\n    img, props = imgio.read_images(images)\n    seg, segprops = imgio.read_seg(segmentation)"
  },
  {
    "path": "nnunetv2/imageio/nibabel_reader_writer.py",
    "content": "#    Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center\n#    (DKFZ), Heidelberg, Germany\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not use this file except in compliance with the License.\n#    You may obtain a copy of the License at\n#\n#        http://www.apache.org/licenses/LICENSE-2.0\n#\n#    Unless required by applicable law or agreed to in writing, software\n#    distributed under the License is distributed on an \"AS IS\" BASIS,\n#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#    See the License for the specific language governing permissions and\n#    limitations under the License.\nimport warnings\nfrom typing import Tuple, Union, List\nimport numpy as np\nfrom nibabel.orientations import io_orientation, axcodes2ornt, ornt_transform\n\nfrom nnunetv2.imageio.base_reader_writer import BaseReaderWriter\nimport nibabel\n\nfrom nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO\n\n\nclass NibabelIO(BaseReaderWriter):\n    \"\"\"\n    Nibabel loads the images in a different order than sitk. We convert the axes to the sitk order to be\n    consistent. 
This is of course considered properly in segmentation export as well.\n\n    IMPORTANT: Run nnUNetv2_plot_overlay_pngs to verify that this did not destroy the alignment of data and seg!\n    \"\"\"\n    supported_file_endings = [\n        '.nii',\n        '.nii.gz',\n    ]\n\n    def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:\n        images = []\n        original_affines = []\n\n        spacings_for_nnunet = []\n        for f in image_fnames:\n            nib_image = nibabel.load(f)\n            assert nib_image.ndim == 3, 'only 3d images are supported by NibabelIO'\n            original_affine = nib_image.affine\n\n            original_affines.append(original_affine)\n\n            # spacing is taken in reverse order to be consistent with SimpleITK axis ordering (confusing, I know...)\n            spacings_for_nnunet.append(\n                [float(i) for i in nib_image.header.get_zooms()[::-1]]\n            )\n\n            # transpose image to be consistent with the way SimpleITk reads images. Yeah. Annoying.\n            images.append(nib_image.get_fdata().transpose((2, 1, 0))[None])\n\n        if not self._check_all_same([i.shape for i in images]):\n            print('ERROR! Not all input images have the same shape!')\n            print('Shapes:')\n            print([i.shape for i in images])\n            print('Image files:')\n            print(image_fnames)\n            raise RuntimeError()\n        if not self._check_all_same_array(original_affines):\n            print('WARNING! Not all input images have the same original_affines!')\n            print('Affines:')\n            print(original_affines)\n            print('Image files:')\n            print(image_fnames)\n            print(\n                'It is up to you to decide whether that\\'s a problem. 
You should run nnUNetv2_plot_overlay_pngs to verify '\n                'that segmentations and data overlap.')\n        if not self._check_all_same(spacings_for_nnunet):\n            print('ERROR! Not all input images have the same spacing_for_nnunet! This might be caused by them not '\n                  'having the same affine')\n            print('spacings_for_nnunet:')\n            print(spacings_for_nnunet)\n            print('Image files:')\n            print(image_fnames)\n            raise RuntimeError()\n\n        dict = {\n            'nibabel_stuff': {\n                'original_affine': original_affines[0],\n            },\n            'spacing': spacings_for_nnunet[0]\n        }\n        return np.vstack(images, dtype=np.float32, casting='unsafe'), dict\n\n    def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]:\n        return self.read_images((seg_fname,))\n\n    def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None:\n        # revert transpose\n        seg = seg.transpose((2, 1, 0)).astype(np.uint8 if np.max(seg) < 255 else np.uint16, copy=False)\n        seg_nib = nibabel.Nifti1Image(seg, affine=properties['nibabel_stuff']['original_affine'])\n        nibabel.save(seg_nib, output_fname)\n\n\nclass NibabelIOWithReorient(BaseReaderWriter):\n    \"\"\"\n    Reorients images to RAS\n\n    Nibabel loads the images in a different order than sitk. We convert the axes to the sitk order to be\n    consistent. 
This is of course considered properly in segmentation export as well.\n\n    IMPORTANT: Run nnUNetv2_plot_overlay_pngs to verify that this did not destroy the alignment of data and seg!\n    \"\"\"\n    supported_file_endings = [\n        '.nii',\n        '.nii.gz',\n    ]\n\n    def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:\n        images = []\n        original_affines = []\n        reoriented_affines = []\n\n        spacings_for_nnunet = []\n        for f in image_fnames:\n            nib_image = nibabel.load(f)\n            assert nib_image.ndim == 3, 'only 3d images are supported by NibabelIO'\n            original_affine = nib_image.affine\n            reoriented_image = nib_image.as_reoriented(io_orientation(original_affine))\n            reoriented_affine = reoriented_image.affine\n\n            original_affines.append(original_affine)\n            reoriented_affines.append(reoriented_affine)\n\n            # spacing is taken in reverse order to be consistent with SimpleITK axis ordering (confusing, I know...)\n            spacings_for_nnunet.append(\n                [float(i) for i in reoriented_image.header.get_zooms()[::-1]]\n            )\n\n            # transpose image to be consistent with the way SimpleITk reads images. Yeah. Annoying.\n            images.append(reoriented_image.get_fdata().transpose((2, 1, 0))[None])\n\n        if not self._check_all_same([i.shape for i in images]):\n            print('ERROR! Not all input images have the same shape!')\n            print('Shapes:')\n            print([i.shape for i in images])\n            print('Image files:')\n            print(image_fnames)\n            raise RuntimeError()\n        if not self._check_all_same_array(reoriented_affines):\n            print('WARNING! 
Not all input images have the same reoriented_affines!')\n            print('Affines:')\n            print(reoriented_affines)\n            print('Image files:')\n            print(image_fnames)\n            print(\n                'It is up to you to decide whether that\\'s a problem. You should run nnUNetv2_plot_overlay_pngs to verify '\n                'that segmentations and data overlap.')\n        if not self._check_all_same(spacings_for_nnunet):\n            print('ERROR! Not all input images have the same spacing_for_nnunet! This might be caused by them not '\n                  'having the same affine')\n            print('spacings_for_nnunet:')\n            print(spacings_for_nnunet)\n            print('Image files:')\n            print(image_fnames)\n            raise RuntimeError()\n\n        dict = {\n            'nibabel_stuff': {\n                'original_affine': original_affines[0],\n                'reoriented_affine': reoriented_affines[0],\n            },\n            'spacing': spacings_for_nnunet[0]\n        }\n        return np.vstack(images, dtype=np.float32, casting='unsafe'), dict\n\n    def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]:\n        return self.read_images((seg_fname,))\n\n    def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None:\n        # revert transpose\n        seg = seg.transpose((2, 1, 0)).astype(np.uint8 if np.max(seg) < 255 else np.uint16, copy=False)\n\n        seg_nib = nibabel.Nifti1Image(seg, affine=properties['nibabel_stuff']['reoriented_affine'])\n        # Solution from https://github.com/nipy/nibabel/issues/1063#issuecomment-967124057\n        img_ornt = io_orientation(properties['nibabel_stuff']['original_affine'])\n        ras_ornt = axcodes2ornt(\"RAS\")\n        from_canonical = ornt_transform(ras_ornt, img_ornt)\n        seg_nib_reoriented = seg_nib.as_reoriented(from_canonical)\n        if not np.allclose(properties['nibabel_stuff']['original_affine'], 
seg_nib_reoriented.affine):\n            print(f'WARNING: Restored affine does not match original affine. File: {output_fname}')\n            print('Original affine\\n', properties['nibabel_stuff']['original_affine'])\n            print('Restored affine\\n', seg_nib_reoriented.affine)\n        nibabel.save(seg_nib_reoriented, output_fname)\n\n\nif __name__ == '__main__':\n    img_file = '/media/isensee/raw_data/nnUNet_raw/Dataset220_KiTS2023/imagesTr/case_00004_0000.nii.gz'\n    seg_file = '/media/isensee/raw_data/nnUNet_raw/Dataset220_KiTS2023/labelsTr/case_00004.nii.gz'\n\n    nibio = NibabelIO()\n    # images, dct = nibio.read_images([img_file])\n    seg, dctseg = nibio.read_seg(seg_file)\n\n    nibio_r = NibabelIOWithReorient()\n    # images_r, dct_r = nibio_r.read_images([img_file])\n    seg_r, dctseg_r = nibio_r.read_seg(seg_file)\n\n    sitkio = SimpleITKIO()\n    # images_sitk, dct_sitk = sitkio.read_images([img_file])\n    seg_sitk, dctseg_sitk = sitkio.read_seg(seg_file)\n\n    # write reoriented and original segmentation\n    nibio.write_seg(seg[0], '/home/isensee/seg_nibio.nii.gz', dctseg)\n    nibio_r.write_seg(seg_r[0], '/home/isensee/seg_nibio_r.nii.gz', dctseg_r)\n    sitkio.write_seg(seg_sitk[0], '/home/isensee/seg_nibio_sitk.nii.gz', dctseg_sitk)\n\n    # now load all with sitk to make sure no shapes got messed up\n    a, d1 = sitkio.read_seg('/home/isensee/seg_nibio.nii.gz')\n    b, d2 = sitkio.read_seg('/home/isensee/seg_nibio_r.nii.gz')\n    c, d3 = sitkio.read_seg('/home/isensee/seg_nibio_sitk.nii.gz')\n\n    assert a.shape == b.shape\n    assert b.shape == c.shape\n\n    assert np.all(a == b)\n    assert np.all(b == c)\n"
  },
  {
    "path": "nnunetv2/imageio/reader_writer_registry.py",
    "content": "import traceback\nfrom typing import Type\n\nfrom batchgenerators.utilities.file_and_folder_operations import join\n\nimport nnunetv2\nfrom nnunetv2.imageio.natural_image_reader_writer import NaturalImage2DIO\nfrom nnunetv2.imageio.nibabel_reader_writer import NibabelIO, NibabelIOWithReorient\nfrom nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO\nfrom nnunetv2.imageio.tif_reader_writer import Tiff3DIO\nfrom nnunetv2.imageio.base_reader_writer import BaseReaderWriter\nfrom nnunetv2.utilities.find_class_by_name import recursive_find_python_class\n\nLIST_OF_IO_CLASSES = [\n    NaturalImage2DIO,\n    SimpleITKIO,\n    Tiff3DIO,\n    NibabelIO,\n    NibabelIOWithReorient\n]\n\n\ndef determine_reader_writer_from_dataset_json(dataset_json_content: dict, example_file: str = None,\n                                              allow_nonmatching_filename: bool = False, verbose: bool = True\n                                              ) -> Type[BaseReaderWriter]:\n    if 'overwrite_image_reader_writer' in dataset_json_content.keys() and \\\n            dataset_json_content['overwrite_image_reader_writer'] != 'None':\n        ioclass_name = dataset_json_content['overwrite_image_reader_writer']\n        # trying to find that class in the nnunetv2.imageio module\n        try:\n            ret = recursive_find_reader_writer_by_name(ioclass_name)\n            if verbose: print(f'Using {ret} reader/writer')\n            return ret\n        except RuntimeError:\n            if verbose: print(f'Warning: Unable to find ioclass specified in dataset.json: {ioclass_name}')\n            if verbose: print('Trying to automatically determine desired class')\n    return determine_reader_writer_from_file_ending(dataset_json_content['file_ending'], example_file,\n                                                    allow_nonmatching_filename, verbose)\n\n\ndef determine_reader_writer_from_file_ending(file_ending: str, example_file: str = None, 
allow_nonmatching_filename: bool = False,\n                                             verbose: bool = True):\n    for rw in LIST_OF_IO_CLASSES:\n        if file_ending.lower() in rw.supported_file_endings:\n            if example_file is not None:\n                # if an example file is provided, try if we can actually read it. If not move on to the next reader\n                try:\n                    tmp = rw()\n                    _ = tmp.read_images((example_file,))\n                    if verbose: print(f'Using {rw} as reader/writer')\n                    return rw\n                except:\n                    if verbose: print(f'Failed to open file {example_file} with reader {rw}:')\n                    traceback.print_exc()\n                    pass\n            else:\n                if verbose: print(f'Using {rw} as reader/writer')\n                return rw\n        else:\n            if allow_nonmatching_filename and example_file is not None:\n                try:\n                    tmp = rw()\n                    _ = tmp.read_images((example_file,))\n                    if verbose: print(f'Using {rw} as reader/writer')\n                    return rw\n                except:\n                    if verbose: print(f'Failed to open file {example_file} with reader {rw}:')\n                    if verbose: traceback.print_exc()\n                    pass\n    raise RuntimeError(f\"Unable to determine a reader for file ending {file_ending} and file {example_file} (file None means no file provided).\")\n\n\ndef recursive_find_reader_writer_by_name(rw_class_name: str) -> Type[BaseReaderWriter]:\n    ret = recursive_find_python_class(join(nnunetv2.__path__[0], \"imageio\"), rw_class_name, 'nnunetv2.imageio')\n    if ret is None:\n        raise RuntimeError(\"Unable to find reader writer class '%s'. Please make sure this class is located in the \"\n                           \"nnunetv2.imageio module.\" % rw_class_name)\n    else:\n        return ret\n"
  },
  {
    "path": "nnunetv2/imageio/readme.md",
    "content": "- Derive your adapter from `BaseReaderWriter`.\n- Implement all abstract methods (`read_images`, `read_seg`, `write_seg`).\n- Make sure to support 2D and 3D input images (or raise an error for unsupported inputs).\n- Place it in this folder or nnU-Net won't find it!\n- Add it to `LIST_OF_IO_CLASSES` in `reader_writer_registry.py`.\n\nBam, you're done!"
  },
  {
    "path": "nnunetv2/imageio/simpleitk_reader_writer.py",
    "content": "#    Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center\n#    (DKFZ), Heidelberg, Germany\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not use this file except in compliance with the License.\n#    You may obtain a copy of the License at\n#\n#        http://www.apache.org/licenses/LICENSE-2.0\n#\n#    Unless required by applicable law or agreed to in writing, software\n#    distributed under the License is distributed on an \"AS IS\" BASIS,\n#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#    See the License for the specific language governing permissions and\n#    limitations under the License.\n\nfrom typing import Tuple, Union, List\nimport numpy as np\nfrom nnunetv2.imageio.base_reader_writer import BaseReaderWriter\nimport SimpleITK as sitk\n\n\nclass SimpleITKIO(BaseReaderWriter):\n    supported_file_endings = [\n        '.nii.gz',\n        '.nrrd',\n        '.mha',\n        '.gipl'\n    ]\n\n    def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:\n        images = []\n        spacings = []\n        origins = []\n        directions = []\n\n        spacings_for_nnunet = []\n        for f in image_fnames:\n            itk_image = sitk.ReadImage(f)\n            spacings.append(itk_image.GetSpacing())\n            origins.append(itk_image.GetOrigin())\n            directions.append(itk_image.GetDirection())\n            npy_image = sitk.GetArrayFromImage(itk_image)\n            if npy_image.ndim == 2:\n                # 2d\n                npy_image = npy_image[None, None]\n                max_spacing = max(spacings[-1])\n                spacings_for_nnunet.append((max_spacing * 999, *list(spacings[-1])[::-1]))\n            elif npy_image.ndim == 3:\n                # 3d, as in original nnunet\n                npy_image = npy_image[None]\n                
spacings_for_nnunet.append(list(spacings[-1])[::-1])\n            elif npy_image.ndim == 4:\n                # 4d, multiple modalities in one file\n                spacings_for_nnunet.append(list(spacings[-1])[::-1][1:])\n                pass\n            else:\n                raise RuntimeError(f\"Unexpected number of dimensions: {npy_image.ndim} in file {f}\")\n\n            images.append(npy_image)\n            spacings_for_nnunet[-1] = list(np.abs(spacings_for_nnunet[-1]))\n\n        if not self._check_all_same([i.shape for i in images]):\n            print('ERROR! Not all input images have the same shape!')\n            print('Shapes:')\n            print([i.shape for i in images])\n            print('Image files:')\n            print(image_fnames)\n            raise RuntimeError()\n        if not self._check_all_same(spacings):\n            print('ERROR! Not all input images have the same spacing!')\n            print('Spacings:')\n            print(spacings)\n            print('Image files:')\n            print(image_fnames)\n            raise RuntimeError()\n        if not self._check_all_same(origins):\n            print('WARNING! Not all input images have the same origin!')\n            print('Origins:')\n            print(origins)\n            print('Image files:')\n            print(image_fnames)\n            print('It is up to you to decide whether that\\'s a problem. You should run nnUNetv2_plot_overlay_pngs to verify '\n                  'that segmentations and data overlap.')\n        if not self._check_all_same(directions):\n            print('WARNING! Not all input images have the same direction!')\n            print('Directions:')\n            print(directions)\n            print('Image files:')\n            print(image_fnames)\n            print('It is up to you to decide whether that\\'s a problem. 
You should run nnUNetv2_plot_overlay_pngs to verify '\n                  'that segmentations and data overlap.')\n        if not self._check_all_same(spacings_for_nnunet):\n            print('ERROR! Not all input images have the same spacing_for_nnunet! (This should not happen and must be a '\n                  'bug. Please report!')\n            print('spacings_for_nnunet:')\n            print(spacings_for_nnunet)\n            print('Image files:')\n            print(image_fnames)\n            raise RuntimeError()\n\n        dict = {\n            'sitk_stuff': {\n                # this saves the sitk geometry information. This part is NOT used by nnU-Net!\n                'spacing': spacings[0],\n                'origin': origins[0],\n                'direction': directions[0]\n            },\n            # the spacing is inverted with [::-1] because sitk returns the spacing in the wrong order lol. Image arrays\n            # are returned x,y,z but spacing is returned z,y,x. Duh.\n            'spacing': spacings_for_nnunet[0]\n        }\n        return np.vstack(images, dtype=np.float32, casting='unsafe'), dict\n\n    def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]:\n        return self.read_images((seg_fname, ))\n\n    def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None:\n        assert seg.ndim == 3, 'segmentation must be 3d. 
If you are exporting a 2d segmentation, please provide it as shape 1,x,y'\n        output_dimension = len(properties['sitk_stuff']['spacing'])\n        assert 1 < output_dimension < 4\n        if output_dimension == 2:\n            seg = seg[0]\n\n        itk_image = sitk.GetImageFromArray(seg.astype(np.uint8 if np.max(seg) < 255 else np.uint16, copy=False))\n        itk_image.SetSpacing(properties['sitk_stuff']['spacing'])\n        itk_image.SetOrigin(properties['sitk_stuff']['origin'])\n        itk_image.SetDirection(properties['sitk_stuff']['direction'])\n\n        sitk.WriteImage(itk_image, output_fname, True)\n\n\nclass SimpleITKIOWithReorient(SimpleITKIO):\n    def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]], orientation = \"RAS\") -> Tuple[np.ndarray, dict]:\n        images = []\n        spacings = []\n        origins = []\n        directions = []\n\n        spacings_for_nnunet = []\n        for f in image_fnames:\n            itk_image = sitk.ReadImage(f)\n            original_orientation = sitk.DICOMOrientImageFilter_GetOrientationFromDirectionCosines(itk_image.GetDirection())\n            itk_image = sitk.DICOMOrient(itk_image, orientation)\n            # print(sitk.DICOMOrientImageFilter_GetOrientationFromDirectionCosines(itk_image.GetDirection()))\n            spacings.append(itk_image.GetSpacing())\n            origins.append(itk_image.GetOrigin())\n            directions.append(itk_image.GetDirection())\n            npy_image = sitk.GetArrayFromImage(itk_image)\n            if npy_image.ndim == 2:\n                # 2d\n                npy_image = npy_image[None, None]\n                max_spacing = max(spacings[-1])\n                spacings_for_nnunet.append((max_spacing * 999, *list(spacings[-1])[::-1]))\n            elif npy_image.ndim == 3:\n                # 3d, as in original nnunet\n                npy_image = npy_image[None]\n                spacings_for_nnunet.append(list(spacings[-1])[::-1])\n            elif 
npy_image.ndim == 4:\n                # 4d, multiple modalities in one file\n                spacings_for_nnunet.append(list(spacings[-1])[::-1][1:])\n                pass\n            else:\n                raise RuntimeError(f\"Unexpected number of dimensions: {npy_image.ndim} in file {f}\")\n\n            images.append(npy_image)\n            spacings_for_nnunet[-1] = list(np.abs(spacings_for_nnunet[-1]))\n\n        if not self._check_all_same([i.shape for i in images]):\n            print('ERROR! Not all input images have the same shape!')\n            print('Shapes:')\n            print([i.shape for i in images])\n            print('Image files:')\n            print(image_fnames)\n            raise RuntimeError()\n        if not self._check_all_same(spacings):\n            print('ERROR! Not all input images have the same spacing!')\n            print('Spacings:')\n            print(spacings)\n            print('Image files:')\n            print(image_fnames)\n            raise RuntimeError()\n        if not self._check_all_same(origins):\n            print('WARNING! Not all input images have the same origin!')\n            print('Origins:')\n            print(origins)\n            print('Image files:')\n            print(image_fnames)\n            print('It is up to you to decide whether that\\'s a problem. You should run nnUNetv2_plot_overlay_pngs to verify '\n                  'that segmentations and data overlap.')\n        if not self._check_all_same(directions):\n            print('WARNING! Not all input images have the same direction!')\n            print('Directions:')\n            print(directions)\n            print('Image files:')\n            print(image_fnames)\n            print('It is up to you to decide whether that\\'s a problem. You should run nnUNetv2_plot_overlay_pngs to verify '\n                  'that segmentations and data overlap.')\n        if not self._check_all_same(spacings_for_nnunet):\n            print('ERROR! 
Not all input images have the same spacing_for_nnunet! (This should not happen and must be a '\n                  'bug. Please report!')\n            print('spacings_for_nnunet:')\n            print(spacings_for_nnunet)\n            print('Image files:')\n            print(image_fnames)\n            raise RuntimeError()\n\n        dict = {\n            'sitk_stuff': {\n                # this saves the sitk geometry information. This part is NOT used by nnU-Net!\n                'spacing': spacings[0],\n                'origin': origins[0],\n                'direction': directions[0],\n                'original_orientation': original_orientation\n            },\n            # the spacing is inverted with [::-1] because sitk returns the spacing in the wrong order lol. Image arrays\n            # are returned x,y,z but spacing is returned z,y,x. Duh.\n            'spacing': spacings_for_nnunet[0]\n        }\n        return np.vstack(images, dtype=np.float32, casting='unsafe'), dict\n\n    def write_seg(self, seg, output_fname, properties):\n        assert seg.ndim == 3, 'segmentation must be 3d. If you are exporting a 2d segmentation, please provide it as shape 1,x,y'\n        output_dimension = len(properties['sitk_stuff']['spacing'])\n        assert 1 < output_dimension < 4\n        if output_dimension == 2:\n            seg = seg[0]\n\n        itk_image = sitk.GetImageFromArray(seg.astype(np.uint8, copy=False))\n        itk_image.SetSpacing(properties['sitk_stuff']['spacing'])\n        itk_image.SetOrigin(properties['sitk_stuff']['origin'])\n        itk_image.SetDirection(properties['sitk_stuff']['direction'])\n        itk_image = sitk.DICOMOrient(itk_image, properties['sitk_stuff']['original_orientation'])\n\n        sitk.WriteImage(itk_image, output_fname, True)\n"
  },
  {
    "path": "nnunetv2/imageio/tif_reader_writer.py",
    "content": "#    Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center\n#    (DKFZ), Heidelberg, Germany\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not use this file except in compliance with the License.\n#    You may obtain a copy of the License at\n#\n#        http://www.apache.org/licenses/LICENSE-2.0\n#\n#    Unless required by applicable law or agreed to in writing, software\n#    distributed under the License is distributed on an \"AS IS\" BASIS,\n#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#    See the License for the specific language governing permissions and\n#    limitations under the License.\nimport os.path\nfrom typing import Tuple, Union, List\nimport numpy as np\nfrom nnunetv2.imageio.base_reader_writer import BaseReaderWriter\nimport tifffile\nfrom batchgenerators.utilities.file_and_folder_operations import isfile, load_json, save_json, split_path, join\n\n\nclass Tiff3DIO(BaseReaderWriter):\n    \"\"\"\n    Reads and writes 3D tif(f) images. Uses the tifffile package. Ignores metadata (for now)!\n\n    If you have 2D tiffs, use NaturalImage2DIO\n\n    Supports the use of auxiliary files for spacing information. If used, the auxiliary files are expected to end\n    with .json and omit the channel identifier. So, for example, the auxiliary file corresponding to image\n    image1_0000.tif is expected to be image1.json!\n    \"\"\"\n    supported_file_endings = [\n        '.tif',\n        '.tiff',\n    ]\n\n    def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:\n        # figure out file ending used here\n        ending = '.' 
+ image_fnames[0].split('.')[-1]\n        assert ending.lower() in self.supported_file_endings, f'Ending {ending} not supported by {self.__class__.__name__}'\n        ending_length = len(ending)\n        truncate_length = ending_length + 5 # 5 comes from len(_0000)\n\n        images = []\n        for f in image_fnames:\n            image = tifffile.imread(f)\n            if image.ndim != 3:\n                raise RuntimeError(f\"Only 3D images are supported! File: {f}\")\n            images.append(image[None])\n\n        # see if aux file can be found\n        expected_aux_file = image_fnames[0][:-truncate_length] + '.json'\n        if isfile(expected_aux_file):\n            spacing = load_json(expected_aux_file)['spacing']\n            assert len(spacing) == 3, f'spacing must have 3 entries, one for each dimension of the image. File: {expected_aux_file}'\n        else:\n            print(f'WARNING no spacing file found for images {image_fnames}\\nAssuming spacing (1, 1, 1).')\n            spacing = (1, 1, 1)\n\n        if not self._check_all_same([i.shape for i in images]):\n            print('ERROR! 
Not all input images have the same shape!')\n            print('Shapes:')\n            print([i.shape for i in images])\n            print('Image files:')\n            print(image_fnames)\n            raise RuntimeError()\n\n        return np.vstack(images, dtype=np.float32, casting='unsafe'), {'spacing': spacing}\n\n    def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None:\n        # not ideal but I really have no clue how to set spacing/resolution information properly in tif files haha\n        tifffile.imwrite(output_fname, data=seg.astype(np.uint8 if np.max(seg) < 255 else np.uint16, copy=False), compression='zlib')\n        file = os.path.basename(output_fname)\n        out_dir = os.path.dirname(output_fname)\n        ending = file.split('.')[-1]\n        save_json({'spacing': properties['spacing']}, join(out_dir, file[:-(len(ending) + 1)] + '.json'))\n\n    def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]:\n        # figure out file ending used here\n        ending = '.' + seg_fname.split('.')[-1]\n        assert ending.lower() in self.supported_file_endings, f'Ending {ending} not supported by {self.__class__.__name__}'\n        ending_length = len(ending)\n\n        seg = tifffile.imread(seg_fname)\n        if seg.ndim != 3:\n            raise RuntimeError(f\"Only 3D images are supported! File: {seg_fname}\")\n        seg = seg[None]\n\n        # see if aux file can be found\n        expected_aux_file = seg_fname[:-ending_length] + '.json'\n        if isfile(expected_aux_file):\n            spacing = load_json(expected_aux_file)['spacing']\n            assert len(spacing) == 3, f'spacing must have 3 entries, one for each dimension of the image. 
File: {expected_aux_file}'\n            assert all([i > 0 for i in spacing]), f\"Spacing must be > 0, spacing: {spacing}\"\n        else:\n            print(f'WARNING no spacing file found for segmentation {seg_fname}\\nAssuming spacing (1, 1, 1).')\n            spacing = (1, 1, 1)\n\n        return seg.astype(np.float32, copy=False), {'spacing': spacing}\n"
  },
  {
    "path": "nnunetv2/inference/JHU_inference.py",
    "content": "import argparse\nimport multiprocessing\nimport os\nfrom time import sleep\nfrom typing import Union\n\nimport numpy as np\nimport torch\nfrom batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter\nfrom batchgenerators.utilities.file_and_folder_operations import load_json, save_pickle, join, maybe_mkdir_p, subdirs\n\nfrom nnunetv2.configuration import default_num_processes\nfrom nnunetv2.inference.export_prediction import convert_predicted_logits_to_segmentation_with_correct_shape\nfrom nnunetv2.inference.predict_from_raw_data import nnUNetPredictor\nfrom nnunetv2.inference.sliding_window_prediction import compute_gaussian\nfrom nnunetv2.utilities.file_path_utilities import check_workers_alive_and_busy\nfrom nnunetv2.utilities.helpers import empty_cache\nfrom nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager\n\n\ndef export_prediction_from_logits_singleFiles(\n        predicted_array_or_file: Union[np.ndarray, torch.Tensor],\n        properties_dict: dict,\n        configuration_manager: ConfigurationManager,\n        plans_manager: PlansManager,\n        dataset_json_dict_or_file: Union[dict, str],\n        output_file_truncated: str,\n        save_probabilities: bool = False):\n    \"\"\"\n    This function generates the output structure expected by the JHU benchmark. We interpret output_file_truncated\n    as the output folder. 
We create 'predictions' subfolders and populate them with the label maps\n    \"\"\"\n\n    if isinstance(dataset_json_dict_or_file, str):\n        dataset_json_dict_or_file = load_json(dataset_json_dict_or_file)\n\n    label_manager = plans_manager.get_label_manager(dataset_json_dict_or_file)\n    ret = convert_predicted_logits_to_segmentation_with_correct_shape(\n        predicted_array_or_file, plans_manager, configuration_manager, label_manager, properties_dict,\n        return_probabilities=save_probabilities\n    )\n    del predicted_array_or_file\n\n    # save\n    if save_probabilities:\n        segmentation_final, probabilities_final = ret\n        np.savez_compressed(output_file_truncated + '.npz', probabilities=probabilities_final)\n        save_pickle(properties_dict, output_file_truncated + '.pkl')\n        del probabilities_final, ret\n    else:\n        segmentation_final = ret\n        del ret\n\n    rw = plans_manager.image_reader_writer_class()\n    output_folder = join(output_file_truncated, 'predictions')\n    maybe_mkdir_p(output_folder)\n    label_name_dict = {j: i for i, j in label_manager.label_dict.items()}\n    for l in label_manager.foreground_labels:\n        label_name = label_name_dict[l]\n        rw.write_seg(\n            (segmentation_final == l).astype(np.uint8, copy=False),\n            join(output_folder, label_name + dataset_json_dict_or_file['file_ending']),\n            properties_dict\n        )\n\n\nclass JHUPredictor(nnUNetPredictor):\n    def predict_from_data_iterator(self,\n                                   data_iterator,\n                                   save_probabilities: bool = False,\n                                   num_processes_segmentation_export: int = default_num_processes):\n        \"\"\"\n        We replace export_prediction_from_logits with export_prediction_from_logits_singleFiles to comply with JHU\n        benchmark output format expectations\n        \"\"\"\n        with 
multiprocessing.get_context(\"spawn\").Pool(num_processes_segmentation_export) as export_pool:\n            worker_list = [i for i in export_pool._pool]\n            r = []\n            for preprocessed in data_iterator:\n                data = preprocessed['data']\n                if isinstance(data, str):\n                    delfile = data\n                    data = torch.from_numpy(np.load(data))\n                    os.remove(delfile)\n\n                ofile = preprocessed['ofile']\n                if ofile is not None:\n                    print(f'\\nPredicting {os.path.basename(ofile)}:')\n                else:\n                    print(f'\\nPredicting image of shape {data.shape}:')\n\n                print(f'perform_everything_on_device: {self.perform_everything_on_device}')\n\n                properties = preprocessed['data_properties']\n\n                # let's not get into a runaway situation where the GPU predicts so fast that the disk gets swamped with\n                # npy files\n                proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2)\n                while not proceed:\n                    sleep(0.1)\n                    proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2)\n\n                prediction = self.predict_logits_from_preprocessed_data(data).cpu()\n\n                if ofile is not None:\n                    # this needs to go into background processes\n                    # export_prediction_from_logits(prediction, properties, self.configuration_manager, self.plans_manager,\n                    #                               self.dataset_json, ofile, save_probabilities)\n                    print('sending off prediction to background worker for resampling and export')\n                    r.append(\n                        export_pool.starmap_async(\n                            export_prediction_from_logits_singleFiles,\n               
             ((prediction, properties, self.configuration_manager, self.plans_manager,\n                              self.dataset_json, ofile, save_probabilities),)\n                        )\n                    )\n                else:\n                    # convert_predicted_logits_to_segmentation_with_correct_shape(\n                    #             prediction, self.plans_manager,\n                    #              self.configuration_manager, self.label_manager,\n                    #              properties,\n                    #              save_probabilities)\n\n                    print('sending off prediction to background worker for resampling')\n                    r.append(\n                        export_pool.starmap_async(\n                            convert_predicted_logits_to_segmentation_with_correct_shape, (\n                                (prediction, self.plans_manager,\n                                 self.configuration_manager, self.label_manager,\n                                 properties,\n                                 save_probabilities),)\n                        )\n                    )\n                if ofile is not None:\n                    print(f'done with {os.path.basename(ofile)}')\n                else:\n                    print(f'\\nDone with image of shape {data.shape}:')\n            ret = [i.get()[0] for i in r]\n\n        if isinstance(data_iterator, MultiThreadedAugmenter):\n            data_iterator._finish()\n\n        # clear lru cache\n        compute_gaussian.cache_clear()\n        # clear device cache\n        empty_cache(self.device)\n        return ret\n\n\nif __name__ == '__main__':\n    # python nnunetv2/inference/JHU_inference.py /home/isensee/Downloads/AbdomenAtlasTest /home/isensee/Downloads/AbdomenAtlasTest_pred -model /home/isensee/temp/JHU/trained_model_ep3850\n    # /home/isensee/temp/JHU/trained_model_ep3850\n    # /home/isensee/Downloads/AbdomenAtlasTest\n    # 
/home/isensee/Downloads/AbdomenAtlasTest_pred\n\n    os.environ['nnUNet_compile'] = 'f'\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('input_dir', type=str)\n    parser.add_argument('output_dir', type=str)\n    parser.add_argument('-model', required=True, type=str)\n    parser.add_argument('--disable_tqdm', required=False, action='store_true', default=False)\n    args = parser.parse_args()\n\n    predictor = JHUPredictor(\n        tile_step_size=0.5,\n        use_gaussian=True,\n        use_mirroring=True,\n        perform_everything_on_device=True,\n        device=torch.device('cuda', 0),\n        verbose=False,\n        verbose_preprocessing=False,\n        allow_tqdm=not args.disable_tqdm\n    )\n\n    predictor.initialize_from_trained_model_folder(\n        args.model,\n        ('all', ),\n        'checkpoint_final.pth'\n    )\n\n    # we need to create list of list of input files\n    input_caseids = subdirs(args.input_dir, join=False)\n    input_files = [[join(args.input_dir, i, 'ct.nii.gz')] for i in input_caseids]\n    output_folders = [join(args.output_dir, i) for i in input_caseids]\n\n    predictor.predict_from_files(\n        input_files,\n        output_folders,\n        save_probabilities=False,\n        overwrite=True,\n        num_processes_preprocessing=2,\n        num_processes_segmentation_export=3,\n        folder_with_segs_from_prev_stage=None,\n        num_parts=1,\n        part_id=0\n    )\n"
  },
  {
    "path": "nnunetv2/inference/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/inference/data_iterators.py",
    "content": "import multiprocessing\nimport queue\nfrom torch.multiprocessing import Event, Queue, Manager\n\nfrom time import sleep\nfrom typing import Union, List\n\nimport numpy as np\nimport torch\nfrom batchgenerators.dataloading.data_loader import DataLoader\n\nfrom nnunetv2.preprocessing.preprocessors.default_preprocessor import DefaultPreprocessor\nfrom nnunetv2.utilities.label_handling.label_handling import convert_labelmap_to_one_hot\nfrom nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager\n\n\ndef preprocess_fromfiles_save_to_queue(list_of_lists: List[List[str]],\n                                       list_of_segs_from_prev_stage_files: Union[None, List[str]],\n                                       output_filenames_truncated: Union[None, List[str]],\n                                       plans_manager: PlansManager,\n                                       dataset_json: dict,\n                                       configuration_manager: ConfigurationManager,\n                                       target_queue: Queue,\n                                       done_event: Event,\n                                       abort_event: Event,\n                                       verbose: bool = False):\n    try:\n        label_manager = plans_manager.get_label_manager(dataset_json)\n        preprocessor = configuration_manager.preprocessor_class(verbose=verbose)\n        for idx in range(len(list_of_lists)):\n            data, seg, data_properties = preprocessor.run_case(list_of_lists[idx],\n                                                               list_of_segs_from_prev_stage_files[\n                                                                   idx] if list_of_segs_from_prev_stage_files is not None else None,\n                                                               plans_manager,\n                                                               configuration_manager,\n                                
                               dataset_json)\n            if list_of_segs_from_prev_stage_files is not None and list_of_segs_from_prev_stage_files[idx] is not None:\n                seg_onehot = convert_labelmap_to_one_hot(seg[0], label_manager.foreground_labels, data.dtype)\n                data = np.vstack((data, seg_onehot))\n\n            data = torch.from_numpy(data).to(dtype=torch.float32, memory_format=torch.contiguous_format)\n\n            item = {'data': data, 'data_properties': data_properties,\n                    'ofile': output_filenames_truncated[idx] if output_filenames_truncated is not None else None}\n            success = False\n            while not success:\n                try:\n                    if abort_event.is_set():\n                        return\n                    target_queue.put(item, timeout=0.01)\n                    success = True\n                except queue.Full:\n                    pass\n        done_event.set()\n    except Exception as e:\n        # print(Exception, e)\n        abort_event.set()\n        raise e\n\n\ndef preprocessing_iterator_fromfiles(list_of_lists: List[List[str]],\n                                     list_of_segs_from_prev_stage_files: Union[None, List[str]],\n                                     output_filenames_truncated: Union[None, List[str]],\n                                     plans_manager: PlansManager,\n                                     dataset_json: dict,\n                                     configuration_manager: ConfigurationManager,\n                                     num_processes: int,\n                                     pin_memory: bool = False,\n                                     verbose: bool = False):\n    context = multiprocessing.get_context('spawn')\n    manager = Manager()\n    num_processes = min(len(list_of_lists), num_processes)\n    assert num_processes >= 1\n    processes = []\n    done_events = []\n    target_queues = []\n    abort_event = manager.Event()\n    
for i in range(num_processes):\n        event = manager.Event()\n        queue = manager.Queue(maxsize=1)\n        pr = context.Process(target=preprocess_fromfiles_save_to_queue,\n                     args=(\n                         list_of_lists[i::num_processes],\n                         list_of_segs_from_prev_stage_files[\n                         i::num_processes] if list_of_segs_from_prev_stage_files is not None else None,\n                         output_filenames_truncated[\n                         i::num_processes] if output_filenames_truncated is not None else None,\n                         plans_manager,\n                         dataset_json,\n                         configuration_manager,\n                         queue,\n                         event,\n                         abort_event,\n                         verbose\n                     ), daemon=True)\n        pr.start()\n        target_queues.append(queue)\n        done_events.append(event)\n        processes.append(pr)\n\n    worker_ctr = 0\n    while (not done_events[worker_ctr].is_set()) or (not target_queues[worker_ctr].empty()):\n        # import IPython;IPython.embed()\n        if not target_queues[worker_ctr].empty():\n            item = target_queues[worker_ctr].get()\n            worker_ctr = (worker_ctr + 1) % num_processes\n        else:\n            all_ok = all(\n                [i.is_alive() or j.is_set() for i, j in zip(processes, done_events)]) and not abort_event.is_set()\n            if not all_ok:\n                raise RuntimeError('Background workers died. Look for the error message further up! If there is '\n                                   'none then your RAM was full and the worker was killed by the OS. 
Use fewer '\n                                   'workers or get more RAM in that case!')\n            sleep(0.01)\n            continue\n        if pin_memory:\n            [i.pin_memory() for i in item.values() if isinstance(i, torch.Tensor)]\n        yield item\n    [p.join() for p in processes]\n\n\nclass PreprocessAdapter(DataLoader):\n    def __init__(self, list_of_lists: List[List[str]],\n                 list_of_segs_from_prev_stage_files: Union[None, List[str]],\n                 preprocessor: DefaultPreprocessor,\n                 output_filenames_truncated: Union[None, List[str]],\n                 plans_manager: PlansManager,\n                 dataset_json: dict,\n                 configuration_manager: ConfigurationManager,\n                 num_threads_in_multithreaded: int = 1):\n        self.preprocessor, self.plans_manager, self.configuration_manager, self.dataset_json = \\\n            preprocessor, plans_manager, configuration_manager, dataset_json\n\n        self.label_manager = plans_manager.get_label_manager(dataset_json)\n\n        if list_of_segs_from_prev_stage_files is None:\n            list_of_segs_from_prev_stage_files = [None] * len(list_of_lists)\n        if output_filenames_truncated is None:\n            output_filenames_truncated = [None] * len(list_of_lists)\n\n        super().__init__(list(zip(list_of_lists, list_of_segs_from_prev_stage_files, output_filenames_truncated)),\n                         1, num_threads_in_multithreaded,\n                         seed_for_shuffle=1, return_incomplete=True,\n                         shuffle=False, infinite=False, sampling_probabilities=None)\n\n        self.indices = list(range(len(list_of_lists)))\n\n    def generate_train_batch(self):\n        idx = self.get_indices()[0]\n        files, seg_prev_stage, ofile = self._data[idx]\n        # if we have a segmentation from the previous stage we have to process it together with the images so that we\n        # can crop it appropriately (if 
needed). Otherwise it would just be resized to the shape of the data after\n        # preprocessing and then there might be misalignments\n        data, seg, data_properties = self.preprocessor.run_case(files, seg_prev_stage, self.plans_manager,\n                                                                self.configuration_manager,\n                                                                self.dataset_json)\n        if seg_prev_stage is not None:\n            seg_onehot = convert_labelmap_to_one_hot(seg[0], self.label_manager.foreground_labels, data.dtype)\n            data = np.vstack((data, seg_onehot))\n\n        data = torch.from_numpy(data)\n\n        return {'data': data, 'data_properties': data_properties, 'ofile': ofile}\n\n\nclass PreprocessAdapterFromNpy(DataLoader):\n    def __init__(self, list_of_images: List[np.ndarray],\n                 list_of_segs_from_prev_stage: Union[List[np.ndarray], None],\n                 list_of_image_properties: List[dict],\n                 truncated_ofnames: Union[List[str], None],\n                 plans_manager: PlansManager, dataset_json: dict, configuration_manager: ConfigurationManager,\n                 num_threads_in_multithreaded: int = 1, verbose: bool = False):\n        preprocessor = configuration_manager.preprocessor_class(verbose=verbose)\n        self.preprocessor, self.plans_manager, self.configuration_manager, self.dataset_json, self.truncated_ofnames = \\\n            preprocessor, plans_manager, configuration_manager, dataset_json, truncated_ofnames\n\n        self.label_manager = plans_manager.get_label_manager(dataset_json)\n\n        if list_of_segs_from_prev_stage is None:\n            list_of_segs_from_prev_stage = [None] * len(list_of_images)\n        if truncated_ofnames is None:\n            truncated_ofnames = [None] * len(list_of_images)\n\n        super().__init__(\n            list(zip(list_of_images, list_of_segs_from_prev_stage, list_of_image_properties, truncated_ofnames)),\n  
          1, num_threads_in_multithreaded,\n            seed_for_shuffle=1, return_incomplete=True,\n            shuffle=False, infinite=False, sampling_probabilities=None)\n\n        self.indices = list(range(len(list_of_images)))\n\n    def generate_train_batch(self):\n        idx = self.get_indices()[0]\n        image, seg_prev_stage, props, ofname = self._data[idx]\n        # if we have a segmentation from the previous stage we have to process it together with the images so that we\n        # can crop it appropriately (if needed). Otherwise it would just be resized to the shape of the data after\n        # preprocessing and then there might be misalignments\n        data, seg, props = self.preprocessor.run_case_npy(image, seg_prev_stage, props,\n                                                   self.plans_manager,\n                                                   self.configuration_manager,\n                                                   self.dataset_json)\n        if seg_prev_stage is not None:\n            seg_onehot = convert_labelmap_to_one_hot(seg[0], self.label_manager.foreground_labels, data.dtype)\n            data = np.vstack((data, seg_onehot))\n\n        data = torch.from_numpy(data)\n\n        return {'data': data, 'data_properties': props, 'ofile': ofname}\n\n\ndef preprocess_fromnpy_save_to_queue(list_of_images: List[np.ndarray],\n                                     list_of_segs_from_prev_stage: Union[List[np.ndarray], None],\n                                     list_of_image_properties: List[dict],\n                                     truncated_ofnames: Union[List[str], None],\n                                     plans_manager: PlansManager,\n                                     dataset_json: dict,\n                                     configuration_manager: ConfigurationManager,\n                                     target_queue: Queue,\n                                     done_event: Event,\n                                     
abort_event: Event,\n                                     verbose: bool = False):\n    try:\n        label_manager = plans_manager.get_label_manager(dataset_json)\n        preprocessor = configuration_manager.preprocessor_class(verbose=verbose)\n        for idx in range(len(list_of_images)):\n            data, seg, props = preprocessor.run_case_npy(list_of_images[idx],\n                                                  list_of_segs_from_prev_stage[\n                                                      idx] if list_of_segs_from_prev_stage is not None else None,\n                                                  list_of_image_properties[idx],\n                                                  plans_manager,\n                                                  configuration_manager,\n                                                  dataset_json)\n            list_of_image_properties[idx] = props\n            if list_of_segs_from_prev_stage is not None and list_of_segs_from_prev_stage[idx] is not None:\n                seg_onehot = convert_labelmap_to_one_hot(seg[0], label_manager.foreground_labels, data.dtype)\n                data = np.vstack((data, seg_onehot))\n\n            data = torch.from_numpy(data).to(dtype=torch.float32, memory_format=torch.contiguous_format)\n\n            item = {'data': data, 'data_properties': list_of_image_properties[idx],\n                    'ofile': truncated_ofnames[idx] if truncated_ofnames is not None else None}\n            success = False\n            while not success:\n                try:\n                    if abort_event.is_set():\n                        return\n                    target_queue.put(item, timeout=0.01)\n                    success = True\n                except queue.Full:\n                    pass\n        done_event.set()\n    except Exception as e:\n        abort_event.set()\n        raise e\n\n\ndef preprocessing_iterator_fromnpy(list_of_images: List[np.ndarray],\n                                   
list_of_segs_from_prev_stage: Union[List[np.ndarray], None],\n                                   list_of_image_properties: List[dict],\n                                   truncated_ofnames: Union[List[str], None],\n                                   plans_manager: PlansManager,\n                                   dataset_json: dict,\n                                   configuration_manager: ConfigurationManager,\n                                   num_processes: int,\n                                   pin_memory: bool = False,\n                                   verbose: bool = False):\n    context = multiprocessing.get_context('spawn')\n    manager = Manager()\n    num_processes = min(len(list_of_images), num_processes)\n    assert num_processes >= 1\n    target_queues = []\n    processes = []\n    done_events = []\n    abort_event = manager.Event()\n    for i in range(num_processes):\n        event = manager.Event()\n        queue = manager.Queue(maxsize=1)\n        pr = context.Process(target=preprocess_fromnpy_save_to_queue,\n                     args=(\n                         list_of_images[i::num_processes],\n                         list_of_segs_from_prev_stage[\n                         i::num_processes] if list_of_segs_from_prev_stage is not None else None,\n                         list_of_image_properties[i::num_processes],\n                         truncated_ofnames[i::num_processes] if truncated_ofnames is not None else None,\n                         plans_manager,\n                         dataset_json,\n                         configuration_manager,\n                         queue,\n                         event,\n                         abort_event,\n                         verbose\n                     ), daemon=True)\n        pr.start()\n        done_events.append(event)\n        processes.append(pr)\n        target_queues.append(queue)\n\n    worker_ctr = 0\n    while (not done_events[worker_ctr].is_set()) or (not 
target_queues[worker_ctr].empty()):\n        if not target_queues[worker_ctr].empty():\n            item = target_queues[worker_ctr].get()\n            worker_ctr = (worker_ctr + 1) % num_processes\n        else:\n            all_ok = all(\n                [i.is_alive() or j.is_set() for i, j in zip(processes, done_events)]) and not abort_event.is_set()\n            if not all_ok:\n                raise RuntimeError('Background workers died. Look for the error message further up! If there is '\n                                   'none then your RAM was full and the worker was killed by the OS. Use fewer '\n                                   'workers or get more RAM in that case!')\n            sleep(0.01)\n            continue\n        if pin_memory:\n            [i.pin_memory() for i in item.values() if isinstance(i, torch.Tensor)]\n        yield item\n    [p.join() for p in processes]\n"
  },
  {
    "path": "nnunetv2/inference/examples.py",
    "content": "if __name__ == '__main__':\n    from nnunetv2.paths import nnUNet_results, nnUNet_raw\n    import torch\n    from batchgenerators.utilities.file_and_folder_operations import join\n    from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor\n    from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO\n\n    # nnUNetv2_predict -d 3 -f 0 -c 3d_lowres -i imagesTs -o imagesTs_predlowres --continue_prediction\n\n    # instantiate the nnUNetPredictor\n    predictor = nnUNetPredictor(\n        tile_step_size=0.5,\n        use_gaussian=True,\n        use_mirroring=True,\n        perform_everything_on_device=True,\n        device=torch.device('cuda', 0),\n        verbose=False,\n        verbose_preprocessing=False,\n        allow_tqdm=True\n    )\n    # initializes the network architecture, loads the checkpoint\n    predictor.initialize_from_trained_model_folder(\n        join(nnUNet_results, 'Dataset003_Liver/nnUNetTrainer__nnUNetPlans__3d_lowres'),\n        use_folds=(0,),\n        checkpoint_name='checkpoint_final.pth',\n    )\n    # variant 1: give input and output folders\n    predictor.predict_from_files(join(nnUNet_raw, 'Dataset003_Liver/imagesTs'),\n                                 join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres'),\n                                 save_probabilities=False, overwrite=False,\n                                 num_processes_preprocessing=2, num_processes_segmentation_export=2,\n                                 folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)\n\n    # variant 2, use list of files as inputs. 
Note how we use nested lists!!!\n    indir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs')\n    outdir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres')\n    predictor.predict_from_files([[join(indir, 'liver_152_0000.nii.gz')],\n                                  [join(indir, 'liver_142_0000.nii.gz')]],\n                                 [join(outdir, 'liver_152.nii.gz'),\n                                  join(outdir, 'liver_142.nii.gz')],\n                                 save_probabilities=False, overwrite=True,\n                                 num_processes_preprocessing=2, num_processes_segmentation_export=2,\n                                 folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)\n\n    # variant 2.5, returns segmentations\n    indir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs')\n    predicted_segmentations = predictor.predict_from_files([[join(indir, 'liver_152_0000.nii.gz')],\n                                                            [join(indir, 'liver_142_0000.nii.gz')]],\n                                                           None,\n                                                           save_probabilities=True, overwrite=True,\n                                                           num_processes_preprocessing=2,\n                                                           num_processes_segmentation_export=2,\n                                                           folder_with_segs_from_prev_stage=None, num_parts=1,\n                                                           part_id=0)\n\n    # predict several npy images\n    from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO\n\n    img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_147_0000.nii.gz')])\n    img2, props2 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_146_0000.nii.gz')])\n    img3, props3 = SimpleITKIO().read_images([join(nnUNet_raw, 
'Dataset003_Liver/imagesTs/liver_145_0000.nii.gz')])\n    img4, props4 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_144_0000.nii.gz')])\n    # we do not set output files so that the segmentations will be returned. You can of course also specify output\n    # files instead (no return value in that case)\n    ret = predictor.predict_from_list_of_npy_arrays([img, img2, img3, img4],\n                                                    None,\n                                                    [props, props2, props3, props4],\n                                                    None, 2, save_probabilities=False,\n                                                    num_processes_segmentation_export=2)\n\n    # predict a single numpy array\n    img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_147_0000.nii.gz')])\n    ret = predictor.predict_single_npy_array(img, props, None, None, True)\n\n    # custom iterator\n\n    img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_147_0000.nii.gz')])\n    img2, props2 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_146_0000.nii.gz')])\n    img3, props3 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_145_0000.nii.gz')])\n    img4, props4 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_144_0000.nii.gz')])\n\n\n    # each element returned by data_iterator must be a dict with 'data', 'ofile' and 'data_properties' keys!\n    # If 'ofile' is None, the result will be returned instead of written to a file\n    # the iterator is responsible for performing the correct preprocessing!\n    # note how the iterator here does not use multiprocessing -> preprocessing will be done in the main thread!\n    # take a look at the default iterators for predict_from_files and predict_from_list_of_npy_arrays\n    # (they both use 
predictor.predict_from_data_iterator) for inspiration!\n    def my_iterator(list_of_input_arrs, list_of_input_props):\n        preprocessor = predictor.configuration_manager.preprocessor_class(verbose=predictor.verbose)\n        for a, p in zip(list_of_input_arrs, list_of_input_props):\n            data, seg, p = preprocessor.run_case_npy(a,\n                                                  None,\n                                                  p,\n                                                  predictor.plans_manager,\n                                                  predictor.configuration_manager,\n                                                  predictor.dataset_json)\n            yield {'data': torch.from_numpy(data).contiguous().pin_memory(), 'data_properties': p, 'ofile': None}\n\n\n    ret = predictor.predict_from_data_iterator(my_iterator([img, img2, img3, img4], [props, props2, props3, props4]),\n                                               save_probabilities=False, num_processes_segmentation_export=3)\n"
  },
  {
    "path": "nnunetv2/inference/export_prediction.py",
    "content": "from typing import Union, List\n\nimport numpy as np\nimport torch\nfrom acvl_utils.cropping_and_padding.bounding_boxes import insert_crop_into_image\nfrom batchgenerators.utilities.file_and_folder_operations import load_json, save_pickle\n\nfrom nnunetv2.configuration import default_num_processes\nfrom nnunetv2.training.dataloading.nnunet_dataset import nnUNetDatasetBlosc2\nfrom nnunetv2.utilities.label_handling.label_handling import LabelManager\nfrom nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager\n\n\ndef convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits: Union[torch.Tensor, np.ndarray],\n                                                                plans_manager: PlansManager,\n                                                                configuration_manager: ConfigurationManager,\n                                                                label_manager: LabelManager,\n                                                                properties_dict: dict,\n                                                                return_probabilities: bool = False,\n                                                                num_threads_torch: int = default_num_processes):\n    old_threads = torch.get_num_threads()\n    torch.set_num_threads(num_threads_torch)\n\n    # resample to original shape\n    spacing_transposed = [properties_dict['spacing'][i] for i in plans_manager.transpose_forward]\n    current_spacing = configuration_manager.spacing if \\\n        len(configuration_manager.spacing) == \\\n        len(properties_dict['shape_after_cropping_and_before_resampling']) else \\\n        [spacing_transposed[0], *configuration_manager.spacing]\n    predicted_logits = configuration_manager.resampling_fn_probabilities(predicted_logits,\n                                            properties_dict['shape_after_cropping_and_before_resampling'],\n                                
            current_spacing,\n                                            [properties_dict['spacing'][i] for i in plans_manager.transpose_forward])\n    # return value of resampling_fn_probabilities can be ndarray or Tensor but that does not matter because\n    # apply_inference_nonlin will convert to torch\n    if not return_probabilities:\n        # this has a faster computation path because we can skip the softmax in regular (not region based) training\n        segmentation = label_manager.convert_logits_to_segmentation(predicted_logits)\n    else:\n        predicted_probabilities = label_manager.apply_inference_nonlin(predicted_logits)\n        segmentation = label_manager.convert_probabilities_to_segmentation(predicted_probabilities)\n    del predicted_logits\n\n    # put segmentation in bbox (revert cropping)\n    segmentation_reverted_cropping = np.zeros(properties_dict['shape_before_cropping'],\n                                              dtype=np.uint8 if len(label_manager.foreground_labels) < 255 else np.uint16)\n    segmentation_reverted_cropping = insert_crop_into_image(segmentation_reverted_cropping, segmentation, properties_dict['bbox_used_for_cropping'])\n    del segmentation\n\n    # segmentation may be torch.Tensor but we continue with numpy\n    if isinstance(segmentation_reverted_cropping, torch.Tensor):\n        segmentation_reverted_cropping = segmentation_reverted_cropping.cpu().numpy()\n\n    # revert transpose\n    segmentation_reverted_cropping = segmentation_reverted_cropping.transpose(plans_manager.transpose_backward)\n    if return_probabilities:\n        # revert cropping\n        predicted_probabilities = label_manager.revert_cropping_on_probabilities(predicted_probabilities,\n                                                                                 properties_dict[\n                                                                                     'bbox_used_for_cropping'],\n                                                   
                              properties_dict[\n                                                                                     'shape_before_cropping'])\n        predicted_probabilities = predicted_probabilities.cpu().numpy()\n        # revert transpose\n        predicted_probabilities = predicted_probabilities.transpose([0] + [i + 1 for i in\n                                                                           plans_manager.transpose_backward])\n        torch.set_num_threads(old_threads)\n        return segmentation_reverted_cropping, predicted_probabilities\n    else:\n        torch.set_num_threads(old_threads)\n        return segmentation_reverted_cropping\n\n\ndef export_prediction_from_logits(predicted_array_or_file: Union[np.ndarray, torch.Tensor], properties_dict: dict,\n                                  configuration_manager: ConfigurationManager,\n                                  plans_manager: PlansManager,\n                                  dataset_json_dict_or_file: Union[dict, str], output_file_truncated: str,\n                                  save_probabilities: bool = False,\n                                  num_threads_torch: int = default_num_processes):\n    # if isinstance(predicted_array_or_file, str):\n    #     tmp = deepcopy(predicted_array_or_file)\n    #     if predicted_array_or_file.endswith('.npy'):\n    #         predicted_array_or_file = np.load(predicted_array_or_file)\n    #     elif predicted_array_or_file.endswith('.npz'):\n    #         predicted_array_or_file = np.load(predicted_array_or_file)['softmax']\n    #     os.remove(tmp)\n\n    if isinstance(dataset_json_dict_or_file, str):\n        dataset_json_dict_or_file = load_json(dataset_json_dict_or_file)\n\n    label_manager = plans_manager.get_label_manager(dataset_json_dict_or_file)\n    ret = convert_predicted_logits_to_segmentation_with_correct_shape(\n        predicted_array_or_file, plans_manager, configuration_manager, label_manager, properties_dict,\n      
  return_probabilities=save_probabilities, num_threads_torch=num_threads_torch\n    )\n    del predicted_array_or_file\n\n    # save\n    if save_probabilities:\n        segmentation_final, probabilities_final = ret\n        np.savez_compressed(output_file_truncated + '.npz', probabilities=probabilities_final)\n        save_pickle(properties_dict, output_file_truncated + '.pkl')\n        del probabilities_final, ret\n    else:\n        segmentation_final = ret\n        del ret\n\n    rw = plans_manager.image_reader_writer_class()\n    rw.write_seg(segmentation_final, output_file_truncated + dataset_json_dict_or_file['file_ending'],\n                 properties_dict)\n\n\ndef resample_and_save(predicted: Union[torch.Tensor, np.ndarray], target_shape: List[int], output_file: str,\n                      plans_manager: PlansManager, configuration_manager: ConfigurationManager, properties_dict: dict,\n                      dataset_json_dict_or_file: Union[dict, str], num_threads_torch: int = default_num_processes,\n                      dataset_class=None) \\\n        -> None:\n\n    old_threads = torch.get_num_threads()\n    torch.set_num_threads(num_threads_torch)\n\n    if isinstance(dataset_json_dict_or_file, str):\n        dataset_json_dict_or_file = load_json(dataset_json_dict_or_file)\n\n    spacing_transposed = [properties_dict['spacing'][i] for i in plans_manager.transpose_forward]\n    # resample to original shape\n    current_spacing = configuration_manager.spacing if \\\n        len(configuration_manager.spacing) == len(properties_dict['shape_after_cropping_and_before_resampling']) else \\\n        [spacing_transposed[0], *configuration_manager.spacing]\n    target_spacing = configuration_manager.spacing if len(configuration_manager.spacing) == \\\n        len(properties_dict['shape_after_cropping_and_before_resampling']) else \\\n        [spacing_transposed[0], *configuration_manager.spacing]\n    predicted_array_or_file = 
configuration_manager.resampling_fn_probabilities(predicted,\n                                                                                target_shape,\n                                                                                current_spacing,\n                                                                                target_spacing)\n\n    # create segmentation (argmax, regions, etc)\n    label_manager = plans_manager.get_label_manager(dataset_json_dict_or_file)\n    segmentation = label_manager.convert_logits_to_segmentation(predicted_array_or_file)\n    # segmentation may be torch.Tensor but we continue with numpy\n    if isinstance(segmentation, torch.Tensor):\n        segmentation = segmentation.cpu().numpy()\n\n    if dataset_class is None:\n        nnUNetDatasetBlosc2.save_seg(segmentation.astype(dtype=np.uint8 if len(label_manager.foreground_labels) < 255 else np.uint16), output_file)\n    else:\n        dataset_class.save_seg(segmentation.astype(dtype=np.uint8 if len(label_manager.foreground_labels) < 255 else np.uint16), output_file)\n    torch.set_num_threads(old_threads)\n"
  },
  {
    "path": "nnunetv2/inference/predict_from_raw_data.py",
    "content": "import inspect\nimport itertools\nimport multiprocessing\nimport os\nfrom copy import deepcopy\nfrom queue import Queue\nfrom threading import Thread\nfrom time import sleep\nfrom typing import Tuple, Union, List, Optional\n\nimport numpy as np\nimport torch\nfrom acvl_utils.cropping_and_padding.padding import pad_nd_image\nfrom batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter\nfrom batchgenerators.utilities.file_and_folder_operations import load_json, join, isfile, maybe_mkdir_p, isdir, subdirs, \\\n    save_json\nfrom torch import nn\nfrom torch._dynamo import OptimizedModule\nfrom torch.nn.parallel import DistributedDataParallel\nfrom tqdm import tqdm\n\nimport nnunetv2\nfrom nnunetv2.configuration import default_num_processes\nfrom nnunetv2.inference.data_iterators import PreprocessAdapterFromNpy, preprocessing_iterator_fromfiles, \\\n    preprocessing_iterator_fromnpy\nfrom nnunetv2.inference.export_prediction import export_prediction_from_logits, \\\n    convert_predicted_logits_to_segmentation_with_correct_shape\nfrom nnunetv2.inference.sliding_window_prediction import compute_gaussian, \\\n    compute_steps_for_sliding_window\nfrom nnunetv2.utilities.file_path_utilities import get_output_folder, check_workers_alive_and_busy\nfrom nnunetv2.utilities.find_class_by_name import recursive_find_python_class\nfrom nnunetv2.utilities.helpers import empty_cache, dummy_context\nfrom nnunetv2.utilities.json_export import recursive_fix_for_json_export\nfrom nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels\nfrom nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager\nfrom nnunetv2.utilities.utils import create_lists_from_splitted_dataset_folder\n\n\nclass nnUNetPredictor(object):\n    def __init__(self,\n                 tile_step_size: float = 0.5,\n                 use_gaussian: bool = True,\n                 use_mirroring: bool = True,\n            
     perform_everything_on_device: bool = True,\n                 device: torch.device = torch.device('cuda'),\n                 verbose: bool = False,\n                 verbose_preprocessing: bool = False,\n                 allow_tqdm: bool = True):\n        self.verbose = verbose\n        self.verbose_preprocessing = verbose_preprocessing\n        self.allow_tqdm = allow_tqdm\n\n        self.plans_manager, self.configuration_manager, self.list_of_parameters, self.network, self.dataset_json, \\\n        self.trainer_name, self.allowed_mirroring_axes, self.label_manager = None, None, None, None, None, None, None, None\n\n        self.tile_step_size = tile_step_size\n        self.use_gaussian = use_gaussian\n        self.use_mirroring = use_mirroring\n        if device.type == 'cuda':\n            torch.backends.cudnn.benchmark = True\n        else:\n            print(f'perform_everything_on_device=True is only supported for cuda devices! Setting this to False')\n            perform_everything_on_device = False\n        self.device = device\n        self.perform_everything_on_device = perform_everything_on_device\n\n    def initialize_from_trained_model_folder(self, model_training_output_dir: str,\n                                             use_folds: Union[Tuple[Union[int, str]], None],\n                                             checkpoint_name: str = 'checkpoint_final.pth'):\n        \"\"\"\n        This is used when making predictions with a trained model\n        \"\"\"\n        if use_folds is None:\n            use_folds = nnUNetPredictor.auto_detect_available_folds(model_training_output_dir, checkpoint_name)\n\n        dataset_json = load_json(join(model_training_output_dir, 'dataset.json'))\n        plans = load_json(join(model_training_output_dir, 'plans.json'))\n        plans_manager = PlansManager(plans)\n\n        if isinstance(use_folds, str):\n            use_folds = [use_folds]\n\n        parameters = []\n        for i, f in 
enumerate(use_folds):\n            f = int(f) if f != 'all' else f\n            checkpoint = torch.load(join(model_training_output_dir, f'fold_{f}', checkpoint_name),\n                                    map_location=torch.device('cpu'), weights_only=False)\n            if i == 0:\n                trainer_name = checkpoint['trainer_name']\n                configuration_name = checkpoint['init_args']['configuration']\n                inference_allowed_mirroring_axes = checkpoint['inference_allowed_mirroring_axes'] if \\\n                    'inference_allowed_mirroring_axes' in checkpoint.keys() else None\n\n            parameters.append(checkpoint['network_weights'])\n\n        configuration_manager = plans_manager.get_configuration(configuration_name)\n        # restore network\n        num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)\n        trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], \"training\", \"nnUNetTrainer\"),\n                                                    trainer_name, 'nnunetv2.training.nnUNetTrainer')\n        if trainer_class is None:\n            raise RuntimeError(f'Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. 
'\n                               f'Please place it there (in any .py file)!')\n        network = trainer_class.build_network_architecture(\n            configuration_manager.network_arch_class_name,\n            configuration_manager.network_arch_init_kwargs,\n            configuration_manager.network_arch_init_kwargs_req_import,\n            num_input_channels,\n            plans_manager.get_label_manager(dataset_json).num_segmentation_heads,\n            enable_deep_supervision=False\n        )\n\n        self.plans_manager = plans_manager\n        self.configuration_manager = configuration_manager\n        self.list_of_parameters = parameters\n\n        # initialize network with first set of parameters, also see https://github.com/MIC-DKFZ/nnUNet/issues/2520\n        network.load_state_dict(parameters[0])\n\n        self.network = network\n\n        self.dataset_json = dataset_json\n        self.trainer_name = trainer_name\n        self.allowed_mirroring_axes = inference_allowed_mirroring_axes\n        self.label_manager = plans_manager.get_label_manager(dataset_json)\n        if ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) \\\n                and not isinstance(self.network, OptimizedModule):\n            print('Using torch.compile')\n            self.network = torch.compile(self.network)\n\n    def manual_initialization(self, network: nn.Module, plans_manager: PlansManager,\n                              configuration_manager: ConfigurationManager, parameters: Optional[List[dict]],\n                              dataset_json: dict, trainer_name: str,\n                              inference_allowed_mirroring_axes: Optional[Tuple[int, ...]]):\n        \"\"\"\n        This is used by the nnUNetTrainer to initialize nnUNetPredictor for the final validation\n        \"\"\"\n        self.plans_manager = plans_manager\n        self.configuration_manager = configuration_manager\n        
self.list_of_parameters = parameters\n        self.network = network\n        self.dataset_json = dataset_json\n        self.trainer_name = trainer_name\n        self.allowed_mirroring_axes = inference_allowed_mirroring_axes\n        self.label_manager = plans_manager.get_label_manager(dataset_json)\n        allow_compile = True\n        allow_compile = allow_compile and ('nnUNet_compile' in os.environ.keys()) and (\n                    os.environ['nnUNet_compile'].lower() in ('true', '1', 't'))\n        allow_compile = allow_compile and not isinstance(self.network, OptimizedModule)\n        if isinstance(self.network, DistributedDataParallel):\n            allow_compile = allow_compile and isinstance(self.network.module, OptimizedModule)\n        if allow_compile:\n            print('Using torch.compile')\n            self.network = torch.compile(self.network)\n\n    @staticmethod\n    def auto_detect_available_folds(model_training_output_dir, checkpoint_name):\n        print('use_folds is None, attempting to auto detect available folds')\n        fold_folders = subdirs(model_training_output_dir, prefix='fold_', join=False)\n        fold_folders = [i for i in fold_folders if i != 'fold_all']\n        fold_folders = [i for i in fold_folders if isfile(join(model_training_output_dir, i, checkpoint_name))]\n        use_folds = [int(i.split('_')[-1]) for i in fold_folders]\n        print(f'found the following folds: {use_folds}')\n        return use_folds\n\n    def _manage_input_and_output_lists(self, list_of_lists_or_source_folder: Union[str, List[List[str]]],\n                                       output_folder_or_list_of_truncated_output_files: Union[None, str, List[str]],\n                                       folder_with_segs_from_prev_stage: str = None,\n                                       overwrite: bool = True,\n                                       part_id: int = 0,\n                                       num_parts: int = 1,\n                            
           save_probabilities: bool = False):\n        if isinstance(list_of_lists_or_source_folder, str):\n            list_of_lists_or_source_folder = create_lists_from_splitted_dataset_folder(list_of_lists_or_source_folder,\n                                                                                       self.dataset_json['file_ending'])\n        print(f'There are {len(list_of_lists_or_source_folder)} cases in the source folder')\n        list_of_lists_or_source_folder = list_of_lists_or_source_folder[part_id::num_parts]\n        caseids = [os.path.basename(i[0])[:-(len(self.dataset_json['file_ending']) + 5)] for i in\n                   list_of_lists_or_source_folder]\n        print(\n            f'I am processing {part_id} out of {num_parts} (max process ID is {num_parts - 1}, we start counting with 0!)')\n        print(f'There are {len(caseids)} cases that I would like to predict')\n\n        if isinstance(output_folder_or_list_of_truncated_output_files, str):\n            output_filename_truncated = [join(output_folder_or_list_of_truncated_output_files, i) for i in caseids]\n        elif isinstance(output_folder_or_list_of_truncated_output_files, list):\n            output_filename_truncated = output_folder_or_list_of_truncated_output_files[part_id::num_parts]\n        else:\n            output_filename_truncated = None\n        seg_from_prev_stage_files = [join(folder_with_segs_from_prev_stage, i + self.dataset_json['file_ending']) if\n                                     folder_with_segs_from_prev_stage is not None else None for i in caseids]\n        # remove already predicted files from the lists\n        if not overwrite and output_filename_truncated is not None:\n            tmp = [isfile(i + self.dataset_json['file_ending']) for i in output_filename_truncated]\n            if save_probabilities:\n                tmp2 = [isfile(i + '.npz') for i in output_filename_truncated]\n                tmp = [i and j for i, j in zip(tmp, tmp2)]\n            
not_existing_indices = [i for i, j in enumerate(tmp) if not j]\n\n            output_filename_truncated = [output_filename_truncated[i] for i in not_existing_indices]\n            list_of_lists_or_source_folder = [list_of_lists_or_source_folder[i] for i in not_existing_indices]\n            seg_from_prev_stage_files = [seg_from_prev_stage_files[i] for i in not_existing_indices]\n            print(f'overwrite was set to {overwrite}, so I am only working on cases that haven\\'t been predicted yet. '\n                  f'That\\'s {len(not_existing_indices)} cases.')\n        return list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files\n\n    def predict_from_files(self,\n                           list_of_lists_or_source_folder: Union[str, List[List[str]]],\n                           output_folder_or_list_of_truncated_output_files: Union[str, None, List[str]],\n                           save_probabilities: bool = False,\n                           overwrite: bool = True,\n                           num_processes_preprocessing: int = default_num_processes,\n                           num_processes_segmentation_export: int = default_num_processes,\n                           folder_with_segs_from_prev_stage: str = None,\n                           num_parts: int = 1,\n                           part_id: int = 0):\n        \"\"\"\n        This is nnU-Net's default function for making predictions. It works best for batch predictions\n        (predicting many images at once).\n        \"\"\"\n        assert part_id < num_parts, (\"Part ID must be smaller than num_parts. Remember that we start counting with 0. 
\"\n                                      \"So if there are 3 parts then valid part IDs are 0, 1, 2\")\n        if isinstance(output_folder_or_list_of_truncated_output_files, str):\n            output_folder = output_folder_or_list_of_truncated_output_files\n        elif isinstance(output_folder_or_list_of_truncated_output_files, list):\n            output_folder = os.path.dirname(output_folder_or_list_of_truncated_output_files[0])\n        else:\n            output_folder = None\n\n        ########################\n        # let's store the input arguments so that it's clear what was used to generate the prediction\n        if output_folder is not None:\n            my_init_kwargs = {}\n            for k in inspect.signature(self.predict_from_files).parameters.keys():\n                my_init_kwargs[k] = locals()[k]\n            my_init_kwargs = deepcopy(\n                my_init_kwargs)  # let's not unintentionally change anything in-place. Take this as a\n            recursive_fix_for_json_export(my_init_kwargs)\n            maybe_mkdir_p(output_folder)\n            save_json(my_init_kwargs, join(output_folder, 'predict_from_raw_data_args.json'))\n\n            # we need these two if we want to do things with the predictions like for example apply postprocessing\n            save_json(self.dataset_json, join(output_folder, 'dataset.json'), sort_keys=False)\n            save_json(self.plans_manager.plans, join(output_folder, 'plans.json'), sort_keys=False)\n        #######################\n\n        # check if we need a prediction from the previous stage\n        if self.configuration_manager.previous_stage_name is not None:\n            assert folder_with_segs_from_prev_stage is not None, \\\n                f'The requested configuration is a cascaded network. It requires the segmentations of the previous ' \\\n                f'stage ({self.configuration_manager.previous_stage_name}) as input. 
Please provide the folder where' \\\n                f' they are located via folder_with_segs_from_prev_stage'\n\n        # sort out input and output filenames\n        list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files = \\\n            self._manage_input_and_output_lists(list_of_lists_or_source_folder,\n                                                output_folder_or_list_of_truncated_output_files,\n                                                folder_with_segs_from_prev_stage, overwrite, part_id, num_parts,\n                                                save_probabilities)\n        if len(list_of_lists_or_source_folder) == 0:\n            return\n\n        data_iterator = self._internal_get_data_iterator_from_lists_of_filenames(list_of_lists_or_source_folder,\n                                                                                 seg_from_prev_stage_files,\n                                                                                 output_filename_truncated,\n                                                                                 num_processes_preprocessing)\n\n        return self.predict_from_data_iterator(data_iterator, save_probabilities, num_processes_segmentation_export)\n\n    def _internal_get_data_iterator_from_lists_of_filenames(self,\n                                                            input_list_of_lists: List[List[str]],\n                                                            seg_from_prev_stage_files: Union[List[str], None],\n                                                            output_filenames_truncated: Union[List[str], None],\n                                                            num_processes: int):\n        return preprocessing_iterator_fromfiles(input_list_of_lists, seg_from_prev_stage_files,\n                                                output_filenames_truncated, self.plans_manager, self.dataset_json,\n                                                
self.configuration_manager, num_processes, self.device.type == 'cuda',\n                                                self.verbose_preprocessing)\n        # preprocessor = self.configuration_manager.preprocessor_class(verbose=self.verbose_preprocessing)\n        # # hijack batchgenerators, yo\n        # # we use the multiprocessing of the batchgenerators dataloader to handle all the background worker stuff. This\n        # # way we don't have to reinvent the wheel here.\n        # num_processes = max(1, min(num_processes, len(input_list_of_lists)))\n        # ppa = PreprocessAdapter(input_list_of_lists, seg_from_prev_stage_files, preprocessor,\n        #                         output_filenames_truncated, self.plans_manager, self.dataset_json,\n        #                         self.configuration_manager, num_processes)\n        # if num_processes == 0:\n        #     mta = SingleThreadedAugmenter(ppa, None)\n        # else:\n        #     mta = MultiThreadedAugmenter(ppa, None, num_processes, 1, None, pin_memory=pin_memory)\n        # return mta\n\n    def get_data_iterator_from_raw_npy_data(self,\n                                            image_or_list_of_images: Union[np.ndarray, List[np.ndarray]],\n                                            segs_from_prev_stage_or_list_of_segs_from_prev_stage: Union[None,\n                                                                                                        np.ndarray,\n                                                                                                        List[\n                                                                                                            np.ndarray]],\n                                            properties_or_list_of_properties: Union[dict, List[dict]],\n                                            truncated_ofname: Union[str, List[str], None],\n                                            num_processes: int = 3):\n\n        list_of_images = 
[image_or_list_of_images] if not isinstance(image_or_list_of_images, list) else \\\n            image_or_list_of_images\n\n        if isinstance(segs_from_prev_stage_or_list_of_segs_from_prev_stage, np.ndarray):\n            segs_from_prev_stage_or_list_of_segs_from_prev_stage = [\n                segs_from_prev_stage_or_list_of_segs_from_prev_stage]\n\n        if isinstance(truncated_ofname, str):\n            truncated_ofname = [truncated_ofname]\n\n        if isinstance(properties_or_list_of_properties, dict):\n            properties_or_list_of_properties = [properties_or_list_of_properties]\n\n        num_processes = min(num_processes, len(list_of_images))\n        pp = preprocessing_iterator_fromnpy(\n            list_of_images,\n            segs_from_prev_stage_or_list_of_segs_from_prev_stage,\n            properties_or_list_of_properties,\n            truncated_ofname,\n            self.plans_manager,\n            self.dataset_json,\n            self.configuration_manager,\n            num_processes,\n            self.device.type == 'cuda',\n            self.verbose_preprocessing\n        )\n\n        return pp\n\n    def predict_from_list_of_npy_arrays(self,\n                                        image_or_list_of_images: Union[np.ndarray, List[np.ndarray]],\n                                        segs_from_prev_stage_or_list_of_segs_from_prev_stage: Union[None,\n                                                                                                    np.ndarray,\n                                                                                                    List[\n                                                                                                        np.ndarray]],\n                                        properties_or_list_of_properties: Union[dict, List[dict]],\n                                        truncated_ofname: Union[str, List[str], None],\n                                        num_processes: int = 3,\n            
                            save_probabilities: bool = False,\n                                        num_processes_segmentation_export: int = default_num_processes):\n        iterator = self.get_data_iterator_from_raw_npy_data(image_or_list_of_images,\n                                                            segs_from_prev_stage_or_list_of_segs_from_prev_stage,\n                                                            properties_or_list_of_properties,\n                                                            truncated_ofname,\n                                                            num_processes)\n        return self.predict_from_data_iterator(iterator, save_probabilities, num_processes_segmentation_export)\n\n    def predict_from_data_iterator(self,\n                                   data_iterator,\n                                   save_probabilities: bool = False,\n                                   num_processes_segmentation_export: int = default_num_processes):\n        \"\"\"\n        each element returned by data_iterator must be a dict with 'data', 'ofile' and 'data_properties' keys!\n        If 'ofile' is None, the result will be returned instead of written to a file\n        \"\"\"\n        with multiprocessing.get_context(\"spawn\").Pool(num_processes_segmentation_export) as export_pool:\n            worker_list = [i for i in export_pool._pool]\n            r = []\n            for preprocessed in data_iterator:\n                data = preprocessed['data']\n                if isinstance(data, str):\n                    delfile = data\n                    data = torch.from_numpy(np.load(data))\n                    os.remove(delfile)\n\n                ofile = preprocessed['ofile']\n                if ofile is not None:\n                    print(f'\\nPredicting {os.path.basename(ofile)}:')\n                else:\n                    print(f'\\nPredicting image of shape {data.shape}:')\n\n                
print(f'perform_everything_on_device: {self.perform_everything_on_device}')\n\n                properties = preprocessed['data_properties']\n\n                # let's not get into a runaway situation where the GPU predicts so fast that the disk has to be swamped with\n                # npy files\n                proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2)\n                while not proceed:\n                    sleep(0.1)\n                    proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2)\n\n                # convert to numpy to prevent uncatchable memory alignment errors from multiprocessing serialization of torch tensors\n                prediction = self.predict_logits_from_preprocessed_data(data).cpu().detach().numpy()\n\n                if ofile is not None:\n                    print('sending off prediction to background worker for resampling and export')\n                    r.append(\n                        export_pool.starmap_async(\n                            export_prediction_from_logits,\n                            ((prediction, properties, self.configuration_manager, self.plans_manager,\n                              self.dataset_json, ofile, save_probabilities),)\n                        )\n                    )\n                else:\n                    print('sending off prediction to background worker for resampling')\n                    r.append(\n                        export_pool.starmap_async(\n                            convert_predicted_logits_to_segmentation_with_correct_shape, (\n                                (prediction, self.plans_manager,\n                                 self.configuration_manager, self.label_manager,\n                                 properties,\n                                 save_probabilities),)\n                        )\n                    )\n                if ofile is not None:\n                    
print(f'done with {os.path.basename(ofile)}')\n                else:\n                    print(f'\\nDone with image of shape {data.shape}:')\n            ret = [i.get()[0] for i in r]\n\n        if isinstance(data_iterator, MultiThreadedAugmenter):\n            data_iterator._finish()\n\n        # clear lru cache\n        compute_gaussian.cache_clear()\n        # clear device cache\n        empty_cache(self.device)\n        return ret\n\n    def predict_single_npy_array(self, input_image: np.ndarray, image_properties: dict,\n                                 segmentation_previous_stage: np.ndarray = None,\n                                 output_file_truncated: str = None,\n                                 save_or_return_probabilities: bool = False):\n        \"\"\"\n        WARNING: SLOW. ONLY USE THIS IF YOU CANNOT GIVE NNUNET MULTIPLE IMAGES AT ONCE FOR SOME REASON.\n\n\n        input_image: Make sure to load the image in the way nnU-Net expects! nnU-Net is trained on a certain axis\n                     ordering which cannot be disturbed in inference,\n                     otherwise you will get bad results. The easiest way to achieve that is to use the same I/O class\n                     for loading images as was used during nnU-Net preprocessing! You can find that class in your\n                     plans.json file under the key \"image_reader_writer\". If you decide to freestyle, know that the\n                     default axis ordering for medical images is the one from SimpleITK. 
If you load with nibabel,\n                     you need to transpose your axes AND your spacing from [x,y,z] to [z,y,x]!\n        image_properties must only have a 'spacing' key!\n        \"\"\"\n        ppa = PreprocessAdapterFromNpy([input_image], [segmentation_previous_stage], [image_properties],\n                                       [output_file_truncated],\n                                       self.plans_manager, self.dataset_json, self.configuration_manager,\n                                       num_threads_in_multithreaded=1, verbose=self.verbose)\n        if self.verbose:\n            print('preprocessing')\n        dct = next(ppa)\n\n        if self.verbose:\n            print('predicting')\n        predicted_logits = self.predict_logits_from_preprocessed_data(dct['data']).cpu()\n\n        if self.verbose:\n            print('resampling to original shape')\n        if output_file_truncated is not None:\n            export_prediction_from_logits(predicted_logits, dct['data_properties'], self.configuration_manager,\n                                          self.plans_manager, self.dataset_json, output_file_truncated,\n                                          save_or_return_probabilities)\n        else:\n            ret = convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits, self.plans_manager,\n                                                                              self.configuration_manager,\n                                                                              self.label_manager,\n                                                                              dct['data_properties'],\n                                                                              return_probabilities=\n                                                                              save_or_return_probabilities)\n            if save_or_return_probabilities:\n                return ret[0], ret[1]\n            else:\n                
return ret\n\n    @torch.inference_mode()\n    def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        IMPORTANT! IF YOU ARE RUNNING THE CASCADE, THE SEGMENTATION FROM THE PREVIOUS STAGE MUST ALREADY BE STACKED ON\n        TOP OF THE IMAGE AS ONE-HOT REPRESENTATION! SEE PreprocessAdapter ON HOW THIS SHOULD BE DONE!\n\n        RETURNED LOGITS HAVE THE SHAPE OF THE INPUT. THEY MUST BE CONVERTED BACK TO THE ORIGINAL IMAGE SIZE.\n        SEE convert_predicted_logits_to_segmentation_with_correct_shape\n        \"\"\"\n        n_threads = torch.get_num_threads()\n        torch.set_num_threads(default_num_processes if default_num_processes < n_threads else n_threads)\n        prediction = None\n\n        for params in self.list_of_parameters:\n\n            # messing with state dict names...\n            if not isinstance(self.network, OptimizedModule):\n                self.network.load_state_dict(params)\n            else:\n                self.network._orig_mod.load_state_dict(params)\n\n            # why not leave prediction on device if perform_everything_on_device? Because this may cause the\n            # second iteration to crash due to OOM. 
Handling that with try/except would add more code bloat than\n            # the computation time it saves is worth\n            if prediction is None:\n                prediction = self.predict_sliding_window_return_logits(data).to('cpu')\n            else:\n                prediction += self.predict_sliding_window_return_logits(data).to('cpu')\n\n        if len(self.list_of_parameters) > 1:\n            prediction /= len(self.list_of_parameters)\n\n        if self.verbose: print('Prediction done')\n        torch.set_num_threads(n_threads)\n        return prediction\n\n    def _internal_get_sliding_window_slicers(self, image_size: Tuple[int, ...]):\n        slicers = []\n        if len(self.configuration_manager.patch_size) < len(image_size):\n            assert len(self.configuration_manager.patch_size) == len(\n                image_size) - 1, 'if tile_size has fewer entries than image_size, ' \\\n                                 'len(tile_size) ' \\\n                                 'must be one shorter than len(image_size) ' \\\n                                 '(only dimension ' \\\n                                 'discrepancy of 1 allowed).'\n            steps = compute_steps_for_sliding_window(image_size[1:], self.configuration_manager.patch_size,\n                                                     self.tile_step_size)\n            if self.verbose: print(f'n_steps {image_size[0] * len(steps[0]) * len(steps[1])}, image size is'\n                                   f' {image_size}, tile_size {self.configuration_manager.patch_size}, '\n                                   f'tile_step_size {self.tile_step_size}\\nsteps:\\n{steps}')\n            for d in range(image_size[0]):\n                for sx in steps[0]:\n                    for sy in steps[1]:\n                        slicers.append(\n                            tuple([slice(None), d, *[slice(si, si + ti) for si, ti in\n                                                     zip((sx, sy), 
self.configuration_manager.patch_size)]]))\n        else:\n            steps = compute_steps_for_sliding_window(image_size, self.configuration_manager.patch_size,\n                                                     self.tile_step_size)\n            if self.verbose: print(\n                f'n_steps {np.prod([len(i) for i in steps])}, image size is {image_size}, tile_size {self.configuration_manager.patch_size}, '\n                f'tile_step_size {self.tile_step_size}\\nsteps:\\n{steps}')\n            for sx in steps[0]:\n                for sy in steps[1]:\n                    for sz in steps[2]:\n                        slicers.append(\n                            tuple([slice(None), *[slice(si, si + ti) for si, ti in\n                                                  zip((sx, sy, sz), self.configuration_manager.patch_size)]]))\n        return slicers\n\n    @torch.inference_mode()\n    def _internal_maybe_mirror_and_predict(self, x: torch.Tensor) -> torch.Tensor:\n        mirror_axes = self.allowed_mirroring_axes if self.use_mirroring else None\n        prediction = self.network(x)\n\n        if mirror_axes is not None:\n            # check for invalid numbers in mirror_axes\n            # x should be 5d for 3d images and 4d for 2d. 
so the max value of mirror_axes cannot exceed len(x.shape) - 3\n            assert max(mirror_axes) <= x.ndim - 3, 'mirror_axes does not match the dimension of the input!'\n\n            mirror_axes = [m + 2 for m in mirror_axes]\n            axes_combinations = [\n                c for i in range(len(mirror_axes)) for c in itertools.combinations(mirror_axes, i + 1)\n            ]\n            for axes in axes_combinations:\n                prediction += torch.flip(self.network(torch.flip(x, axes)), axes)\n            prediction /= (len(axes_combinations) + 1)\n        return prediction\n\n    @torch.inference_mode()\n    def _internal_predict_sliding_window_return_logits(self,\n                                                       data: torch.Tensor,\n                                                       slicers,\n                                                       do_on_device: bool = True,\n                                                       ):\n        predicted_logits = n_predictions = prediction = gaussian = workon = None\n        results_device = self.device if do_on_device else torch.device('cpu')\n\n        def producer(d, slh, q):\n            for s in slh:\n                q.put((torch.clone(d[s][None], memory_format=torch.contiguous_format).to(self.device), s))\n            q.put('end')\n\n        try:\n            empty_cache(self.device)\n\n            # move data to device\n            if self.verbose:\n                print(f'move image to device {results_device}')\n            data = data.to(results_device)\n            queue = Queue(maxsize=2)\n            t = Thread(target=producer, args=(data, slicers, queue))\n            t.start()\n\n            # preallocate arrays\n            if self.verbose:\n                print(f'preallocating results arrays on device {results_device}')\n            predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]),\n                                           
dtype=torch.half,\n                                           device=results_device)\n            n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device)\n\n            if self.use_gaussian:\n                gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,\n                                            value_scaling_factor=10,\n                                            device=results_device)\n            else:\n                gaussian = 1\n\n            if not self.allow_tqdm and self.verbose:\n                print(f'running prediction: {len(slicers)} steps')\n\n            with tqdm(desc=None, total=len(slicers), disable=not self.allow_tqdm) as pbar:\n                while True:\n                    item = queue.get()\n                    if item == 'end':\n                        queue.task_done()\n                        break\n                    workon, sl = item\n                    prediction = self._internal_maybe_mirror_and_predict(workon)[0].to(results_device)\n\n                    if self.use_gaussian:\n                        prediction *= gaussian\n                    predicted_logits[sl] += prediction\n                    n_predictions[sl[1:]] += gaussian\n                    queue.task_done()\n                    pbar.update()\n            queue.join()\n\n            # predicted_logits /= n_predictions\n            torch.div(predicted_logits, n_predictions, out=predicted_logits)\n            # check for infs\n            if torch.any(torch.isinf(predicted_logits)):\n                raise RuntimeError('Encountered inf in predicted array. Aborting... 
If this problem persists, '\n                                   'reduce value_scaling_factor in compute_gaussian or increase the dtype of '\n                                   'predicted_logits to fp32')\n        except Exception as e:\n            del predicted_logits, n_predictions, prediction, gaussian, workon\n            empty_cache(self.device)\n            empty_cache(results_device)\n            raise e\n        return predicted_logits\n\n    @torch.inference_mode()\n    def predict_sliding_window_return_logits(self, input_image: torch.Tensor) \\\n            -> Union[np.ndarray, torch.Tensor]:\n        assert isinstance(input_image, torch.Tensor)\n        self.network = self.network.to(self.device)\n        self.network.eval()\n\n        empty_cache(self.device)\n\n        # Autocast can be annoying\n        # If the device_type is 'cpu' then it's slow as heck on some CPUs (no auto bfloat16 support detection)\n        # and needs to be disabled.\n        # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False\n        # is set. Whyyyyyyy. 
(this is why we don't make use of enabled=False)\n        # So autocast will only be active if we have a cuda device.\n        with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():\n            assert input_image.ndim == 4, 'input_image must be a 4D torch.Tensor (c, x, y, z)'\n\n            if self.verbose:\n                print(f'Input shape: {input_image.shape}')\n                print(\"step_size:\", self.tile_step_size)\n                print(\"mirror_axes:\", self.allowed_mirroring_axes if self.use_mirroring else None)\n\n            # if input_image is smaller than tile_size we need to pad it to tile_size.\n            data, slicer_revert_padding = pad_nd_image(input_image, self.configuration_manager.patch_size,\n                                                       'constant', {'value': 0}, True,\n                                                       None)\n\n            slicers = self._internal_get_sliding_window_slicers(data.shape[1:])\n\n            if self.perform_everything_on_device and self.device.type != 'cpu':\n                # we need try/except here because we can run OOM, in which case we fall back to CPU as the results device\n                try:\n                    predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers,\n                                                                                           self.perform_everything_on_device)\n                except RuntimeError:\n                    print(\n                        'Prediction on device was unsuccessful, probably due to a lack of memory. 
Moving results arrays to CPU')\n                    empty_cache(self.device)\n                    predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, False)\n            else:\n                predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers,\n                                                                                       self.perform_everything_on_device)\n\n            empty_cache(self.device)\n            # revert padding\n            predicted_logits = predicted_logits[(slice(None), *slicer_revert_padding[1:])]\n        return predicted_logits\n\n    def predict_from_files_sequential(self,\n                           list_of_lists_or_source_folder: Union[str, List[List[str]]],\n                           output_folder_or_list_of_truncated_output_files: Union[str, None, List[str]],\n                           save_probabilities: bool = False,\n                           overwrite: bool = True,\n                           folder_with_segs_from_prev_stage: str = None):\n        \"\"\"\n        Just like predict_from_files but doesn't use any multiprocessing. 
Slow, but sometimes necessary\n        \"\"\"\n        if isinstance(output_folder_or_list_of_truncated_output_files, str):\n            output_folder = output_folder_or_list_of_truncated_output_files\n        elif isinstance(output_folder_or_list_of_truncated_output_files, list):\n            output_folder = os.path.dirname(output_folder_or_list_of_truncated_output_files[0])\n            if len(output_folder) == 0:  # just a file was given without a folder\n                output_folder = os.path.curdir\n        else:\n            output_folder = None\n\n        ########################\n        # let's store the input arguments so that it's clear what was used to generate the prediction\n        if output_folder is not None:\n            my_init_kwargs = {}\n            for k in inspect.signature(self.predict_from_files_sequential).parameters.keys():\n                my_init_kwargs[k] = locals()[k]\n            my_init_kwargs = deepcopy(\n                my_init_kwargs)  # let's not unintentionally change anything in-place. Take this as a\n            recursive_fix_for_json_export(my_init_kwargs)\n            # make sure the output folder exists before writing into it (same as predict_from_files)\n            maybe_mkdir_p(output_folder)\n            save_json(my_init_kwargs, join(output_folder, 'predict_from_raw_data_args.json'))\n\n            # we need these two if we want to do things with the predictions like for example apply postprocessing\n            save_json(self.dataset_json, join(output_folder, 'dataset.json'), sort_keys=False)\n            save_json(self.plans_manager.plans, join(output_folder, 'plans.json'), sort_keys=False)\n        #######################\n\n        # check if we need a prediction from the previous stage\n        if self.configuration_manager.previous_stage_name is not None:\n            assert folder_with_segs_from_prev_stage is not None, \\\n                f'The requested configuration is a cascaded network. It requires the segmentations of the previous ' \\\n                f'stage ({self.configuration_manager.previous_stage_name}) as input. 
Please provide the folder where' \\\n                f' they are located via folder_with_segs_from_prev_stage'\n\n        # sort out input and output filenames\n        list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files = \\\n            self._manage_input_and_output_lists(list_of_lists_or_source_folder,\n                                                output_folder_or_list_of_truncated_output_files,\n                                                folder_with_segs_from_prev_stage, overwrite, 0, 1,\n                                                save_probabilities)\n        if len(list_of_lists_or_source_folder) == 0:\n            return\n\n        label_manager = self.plans_manager.get_label_manager(self.dataset_json)\n        preprocessor = self.configuration_manager.preprocessor_class(verbose=self.verbose)\n\n        if output_filename_truncated is None:\n            output_filename_truncated = [None] * len(list_of_lists_or_source_folder)\n        if seg_from_prev_stage_files is None:\n            # one placeholder per case; len(seg_from_prev_stage_files) would fail here because it is None\n            seg_from_prev_stage_files = [None] * len(list_of_lists_or_source_folder)\n\n        ret = []\n        for li, of, sps in zip(list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files):\n            data, seg, data_properties = preprocessor.run_case(\n                li,\n                sps,\n                self.plans_manager,\n                self.configuration_manager,\n                self.dataset_json\n            )\n\n            print(f'perform_everything_on_device: {self.perform_everything_on_device}')\n\n            prediction = self.predict_logits_from_preprocessed_data(torch.from_numpy(data)).cpu()\n\n            if of is not None:\n                export_prediction_from_logits(prediction, data_properties, self.configuration_manager, self.plans_manager,\n                  self.dataset_json, of, save_probabilities)\n            else:\n                
ret.append(convert_predicted_logits_to_segmentation_with_correct_shape(prediction, self.plans_manager,\n                     self.configuration_manager, self.label_manager,\n                     data_properties,\n                     save_probabilities))\n\n        # clear lru cache\n        compute_gaussian.cache_clear()\n        # clear device cache\n        empty_cache(self.device)\n        return ret\n\ndef _getDefaultValue(env: str, dtype: type, default: any) -> any:\n    # read env from the environment and cast it to dtype; fall back to default if unset or not castable\n    try:\n        val = dtype(os.environ.get(env) or default)\n    except Exception:\n        val = default\n    return val\n\ndef predict_entry_point_modelfolder():\n    import argparse\n    parser = argparse.ArgumentParser(description='Use this to run inference with nnU-Net. This function is used when '\n                                                 'you want to manually specify a folder containing a trained nnU-Net '\n                                                 'model. This is useful when the nnunet environment variables '\n                                                 '(nnUNet_results) are not set.')\n    parser.add_argument('-i', type=str, required=True,\n                        help='Input folder. Remember to use the correct channel numberings for your files (_0000 etc). '\n                             'File endings must be the same as the training dataset!')\n    parser.add_argument('-o', type=str, required=True,\n                        help='Output folder. If it does not exist it will be created. Predicted segmentations will '\n                             'have the same name as their source images.')\n    parser.add_argument('-m', type=str, required=True,\n                        help='Folder in which the trained model is. 
Must have subfolders fold_X for the different '\n                             'folds you trained')\n    parser.add_argument('-f', nargs='+', type=str, required=False, default=(0, 1, 2, 3, 4),\n                        help='Specify the folds of the trained model that should be used for prediction. '\n                             'Default: (0, 1, 2, 3, 4)')\n    parser.add_argument('-step_size', type=float, required=False, default=0.5,\n                        help='Step size for sliding window prediction. The larger it is the faster but less accurate '\n                             'the prediction. Default: 0.5. Cannot be larger than 1. We recommend the default.')\n    parser.add_argument('--disable_tta', action='store_true', required=False, default=False,\n                        help='Set this flag to disable test time data augmentation in the form of mirroring. Faster, '\n                             'but less accurate inference. Not recommended.')\n    parser.add_argument('--verbose', action='store_true', help=\"Set this if you like being talked to. You will have \"\n                                                               \"to be a good listener/reader.\")\n    parser.add_argument('--save_probabilities', action='store_true',\n                        help='Set this to export predicted class \"probabilities\". Required if you want to ensemble '\n                             'multiple configurations.')\n    parser.add_argument('--continue_prediction', '--c', action='store_true',\n                        help='Continue an aborted previous prediction (will not overwrite existing files)')\n    parser.add_argument('-chk', type=str, required=False, default='checkpoint_final.pth',\n                        help='Name of the checkpoint you want to use. Default: checkpoint_final.pth')\n    parser.add_argument('-npp', type=int, required=False, default=3,\n                        help='Number of processes used for preprocessing. More is not always better. 
Beware of '\n                             'out-of-RAM issues. Default: 3')\n    parser.add_argument('-nps', type=int, required=False, default=3,\n                        help='Number of processes used for segmentation export. More is not always better. Beware of '\n                             'out-of-RAM issues. Default: 3')\n    parser.add_argument('-prev_stage_predictions', type=str, required=False, default=None,\n                        help='Folder containing the predictions of the previous stage. Required for cascaded models.')\n    parser.add_argument('-device', type=str, default='cuda', required=False,\n                        help=\"Use this to set the device the inference should run with. Available options are 'cuda' \"\n                             \"(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! \"\n                             \"Use CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] instead!\")\n    parser.add_argument('--disable_progress_bar', action='store_true', required=False, default=False,\n                        help='Set this flag to disable progress bar. Recommended for HPC environments (non interactive '\n                             'jobs)')\n    parser.add_argument('--not_on_device', action='store_true', required=False, default=False,\n                        help=\"Set this flag to disable perform_everything_on_device. Recommended for large cases that \"\n                             \"occupy more VRAM than available\")\n\n    print(\n        \"\\n#######################################################################\\nPlease cite the following paper \"\n        \"when using nnU-Net:\\n\"\n        \"Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). \"\n        \"nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. 
\"\n        \"Nature methods, 18(2), 203-211.\\n#######################################################################\\n\")\n\n    args = parser.parse_args()\n    args.f = [i if i == 'all' else int(i) for i in args.f]\n\n    if not isdir(args.o):\n        maybe_mkdir_p(args.o)\n\n    assert args.device in ['cpu', 'cuda',\n                           'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.'\n    if args.device == 'cpu':\n        # let's allow torch to use hella threads\n        import multiprocessing\n        torch.set_num_threads(multiprocessing.cpu_count())\n        device = torch.device('cpu')\n    elif args.device == 'cuda':\n        # multithreading in torch doesn't help nnU-Net if run on GPU\n        torch.set_num_threads(1)\n        torch.set_num_interop_threads(1)\n        device = torch.device('cuda')\n    else:\n        device = torch.device('mps')\n\n    predictor = nnUNetPredictor(tile_step_size=args.step_size,\n                                use_gaussian=True,\n                                use_mirroring=not args.disable_tta,\n                                perform_everything_on_device=not args.not_on_device,\n                                device=device,\n                                verbose=args.verbose,\n                                allow_tqdm=not args.disable_progress_bar,\n                                verbose_preprocessing=args.verbose)\n    predictor.initialize_from_trained_model_folder(args.m, args.f, args.chk)\n    predictor.predict_from_files(args.i, args.o, save_probabilities=args.save_probabilities,\n                                 overwrite=not args.continue_prediction,\n                                 num_processes_preprocessing=args.npp,\n                                 num_processes_segmentation_export=args.nps,\n                                 folder_with_segs_from_prev_stage=args.prev_stage_predictions,\n                                 
num_parts=1, part_id=0)\n\n\ndef predict_entry_point():\n    import argparse\n    parser = argparse.ArgumentParser(description='Use this to run inference with nnU-Net. This function is used when '\n                                                 'you want to manually specify a folder containing a trained nnU-Net '\n                                                 'model. This is useful when the nnunet environment variables '\n                                                 '(nnUNet_results) are not set.')\n    parser.add_argument('-i', type=str, required=True,\n                        help='input folder. Remember to use the correct channel numberings for your files (_0000 etc). '\n                             'File endings must be the same as the training dataset!')\n    parser.add_argument('-o', type=str, required=True,\n                        help='Output folder. If it does not exist it will be created. Predicted segmentations will '\n                             'have the same name as their source images.')\n    parser.add_argument('-d', type=str, required=True,\n                        help='Dataset with which you would like to predict. You can specify either dataset name or id')\n    parser.add_argument('-p', type=str, required=False, default='nnUNetPlans',\n                        help='Plans identifier. Specify the plans in which the desired configuration is located. '\n                             'Default: nnUNetPlans')\n    parser.add_argument('-tr', type=str, required=False, default='nnUNetTrainer',\n                        help='What nnU-Net trainer class was used for training? Default: nnUNetTrainer')\n    parser.add_argument('-c', type=str, required=True,\n                        help='nnU-Net configuration that should be used for prediction. 
Config must be located '\n                             'in the plans specified with -p')\n    parser.add_argument('-f', nargs='+', type=str, required=False, default=(0, 1, 2, 3, 4),\n                        help='Specify the folds of the trained model that should be used for prediction. '\n                             'Default: (0, 1, 2, 3, 4)')\n    parser.add_argument('-step_size', type=float, required=False, default=0.5,\n                        help='Step size for sliding window prediction. The larger it is the faster but less accurate '\n                             'the prediction. Default: 0.5. Cannot be larger than 1. We recommend the default.')\n    parser.add_argument('--disable_tta', action='store_true', required=False, default=False,\n                        help='Set this flag to disable test time data augmentation in the form of mirroring. Faster, '\n                             'but less accurate inference. Not recommended.')\n    parser.add_argument('--verbose', action='store_true', help=\"Set this if you like being talked to. You will have \"\n                                                               \"to be a good listener/reader.\")\n    parser.add_argument('--save_probabilities', action='store_true',\n                        help='Set this to export predicted class \"probabilities\". Required if you want to ensemble '\n                             'multiple configurations.')\n    parser.add_argument('--continue_prediction', action='store_true',\n                        help='Continue an aborted previous prediction (will not overwrite existing files)')\n    parser.add_argument('-chk', type=str, required=False, default='checkpoint_final.pth',\n                        help='Name of the checkpoint you want to use. Default: checkpoint_final.pth')\n    parser.add_argument('-npp', type=int, required=False, default=_getDefaultValue('nnUNet_npp', int, 3),\n                        help='Number of processes used for preprocessing. 
More is not always better. Beware of '\n                             'out-of-RAM issues. Default: 3')\n    parser.add_argument('-nps', type=int, required=False, default=_getDefaultValue('nnUNet_nps', int, 3),\n                        help='Number of processes used for segmentation export. More is not always better. Beware of '\n                             'out-of-RAM issues. Default: 3')\n    parser.add_argument('-prev_stage_predictions', type=str, required=False, default=None,\n                        help='Folder containing the predictions of the previous stage. Required for cascaded models.')\n    parser.add_argument('-num_parts', type=int, required=False, default=1,\n                        help='Number of separate nnUNetv2_predict call that you will be making. Default: 1 (= this one '\n                             'call predicts everything)')\n    parser.add_argument('-part_id', type=int, required=False, default=0,\n                        help='If multiple nnUNetv2_predict exist, which one is this? IDs start with 0 can end with '\n                             'num_parts - 1. So when you submit 5 nnUNetv2_predict calls you need to set -num_parts '\n                             '5 and use -part_id 0, 1, 2, 3 and 4. Simple, right? Note: You are yourself responsible '\n                             'to make these run on separate GPUs! Use CUDA_VISIBLE_DEVICES (google, yo!)')\n    parser.add_argument('-device', type=str, default='cuda', required=False,\n                        help=\"Use this to set the device the inference should run with. Available options are 'cuda' \"\n                             \"(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! \"\n                             \"Use CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] instead!\")\n    parser.add_argument('--disable_progress_bar', action='store_true', required=False, default=False,\n                        help='Set this flag to disable progress bar. 
Recommended for HPC environments (non interactive '\n                             'jobs)')\n    parser.add_argument('--not_on_device', action='store_true', required=False, default=False,\n                        help=\"Set this flag to disable perform_everything_on_device. Recommended for large cases that \"\n                             \"occupy more VRAM than available\")\n\n    print(\n        \"\\n#######################################################################\\nPlease cite the following paper \"\n        \"when using nnU-Net:\\n\"\n        \"Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). \"\n        \"nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. \"\n        \"Nature methods, 18(2), 203-211.\\n#######################################################################\\n\")\n\n    args = parser.parse_args()\n    args.f = [i if i == 'all' else int(i) for i in args.f]\n\n    model_folder = get_output_folder(args.d, args.tr, args.p, args.c)\n\n    if not isdir(args.o):\n        maybe_mkdir_p(args.o)\n\n    # slightly passive aggressive haha\n    assert args.part_id < args.num_parts, 'Do you even read the documentation? See nnUNetv2_predict -h.'\n\n    assert args.device in ['cpu', 'cuda',\n                           'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. 
Got: {args.device}.'\n    if args.device == 'cpu':\n        # let's allow torch to use hella threads\n        import multiprocessing\n        torch.set_num_threads(multiprocessing.cpu_count())\n        device = torch.device('cpu')\n    elif args.device == 'cuda':\n        # multithreading in torch doesn't help nnU-Net if run on GPU\n        torch.set_num_threads(1)\n        torch.set_num_interop_threads(1)\n        device = torch.device('cuda')\n    else:\n        device = torch.device('mps')\n\n    predictor = nnUNetPredictor(tile_step_size=args.step_size,\n                                use_gaussian=True,\n                                use_mirroring=not args.disable_tta,\n                                perform_everything_on_device=not args.not_on_device,\n                                device=device,\n                                verbose=args.verbose,\n                                verbose_preprocessing=args.verbose,\n                                allow_tqdm=not args.disable_progress_bar)\n    predictor.initialize_from_trained_model_folder(\n        model_folder,\n        args.f,\n        checkpoint_name=args.chk\n    )\n    \n    run_sequential = args.nps == 0 and args.npp == 0\n    \n    if run_sequential:\n        \n        print(\"Running in non-multiprocessing mode\")\n        predictor.predict_from_files_sequential(args.i, args.o, save_probabilities=args.save_probabilities,\n                                                overwrite=not args.continue_prediction,\n                                                folder_with_segs_from_prev_stage=args.prev_stage_predictions)\n    \n    else:\n        \n        predictor.predict_from_files(args.i, args.o, save_probabilities=args.save_probabilities,\n                                    overwrite=not args.continue_prediction,\n                                    num_processes_preprocessing=args.npp,\n                                    num_processes_segmentation_export=args.nps,\n                        
            folder_with_segs_from_prev_stage=args.prev_stage_predictions,\n                                    num_parts=args.num_parts,\n                                    part_id=args.part_id)\n    \n    # r = predict_from_raw_data(args.i,\n    #                           args.o,\n    #                           model_folder,\n    #                           args.f,\n    #                           args.step_size,\n    #                           use_gaussian=True,\n    #                           use_mirroring=not args.disable_tta,\n    #                           perform_everything_on_device=True,\n    #                           verbose=args.verbose,\n    #                           save_probabilities=args.save_probabilities,\n    #                           overwrite=not args.continue_prediction,\n    #                           checkpoint_name=args.chk,\n    #                           num_processes_preprocessing=args.npp,\n    #                           num_processes_segmentation_export=args.nps,\n    #                           folder_with_segs_from_prev_stage=args.prev_stage_predictions,\n    #                           num_parts=args.num_parts,\n    #                           part_id=args.part_id,\n    #                           device=device)\n\n\nif __name__ == '__main__':\n    ########################## predict a bunch of files\n    from nnunetv2.paths import nnUNet_results, nnUNet_raw\n\n    predictor = nnUNetPredictor(\n        tile_step_size=0.5,\n        use_gaussian=True,\n        use_mirroring=True,\n        perform_everything_on_device=True,\n        device=torch.device('cuda', 0),\n        verbose=False,\n        verbose_preprocessing=False,\n        allow_tqdm=True\n    )\n    predictor.initialize_from_trained_model_folder(\n        join(nnUNet_results, 'Dataset004_Hippocampus/nnUNetTrainer_5epochs__nnUNetPlans__3d_fullres'),\n        use_folds=(0,),\n        checkpoint_name='checkpoint_final.pth',\n    )\n    # 
predictor.predict_from_files(join(nnUNet_raw, 'Dataset003_Liver/imagesTs'),\n    #                              join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres'),\n    #                              save_probabilities=False, overwrite=False,\n    #                              num_processes_preprocessing=2, num_processes_segmentation_export=2,\n    #                              folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)\n    #\n    # # predict a numpy array\n    # from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO\n    #\n    # img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTr/liver_63_0000.nii.gz')])\n    # ret = predictor.predict_single_npy_array(img, props, None, None, False)\n    #\n    # iterator = predictor.get_data_iterator_from_raw_npy_data([img], None, [props], None, 1)\n    # ret = predictor.predict_from_data_iterator(iterator, False, 1)\n\n    ret = predictor.predict_from_files_sequential(\n        [['/media/isensee/raw_data/nnUNet_raw/Dataset004_Hippocampus/imagesTs/hippocampus_002_0000.nii.gz'], ['/media/isensee/raw_data/nnUNet_raw/Dataset004_Hippocampus/imagesTs/hippocampus_005_0000.nii.gz']],\n        '/home/isensee/temp/tmp', False, True, None\n    )\n\n\n"
  },
  {
    "path": "nnunetv2/inference/readme.md",
    "content": "The nnU-Net inference is now much more dynamic than before, allowing you to more seamlessly integrate nnU-Net into \nyour existing workflows.\nThis readme will give you a quick rundown of your options. This is not a complete guide. Look into the code to learn \nall the details!\n\n# Preface\nIn terms of speed, the most efficient inference strategy is the one done by the nnU-Net defaults! Images are read on \nthe fly and preprocessed in background workers. The main process takes the preprocessed images, predicts them and \nsends the prediction off to another set of background workers which will resize the resulting logits, convert \nthem to a segmentation and export the segmentation.\n\nThe default setup is the best option because \n\n1) loading and preprocessing as well as segmentation export are interlaced with the prediction. The main process can \nfocus on communicating with the compute device (i.e. your GPU) and does not have to do any other processing. \nThis uses your resources as well as possible!\n2) only the images and segmentations that are currently needed are stored in RAM! Imagine predicting many images \nand having to store all of them + the results in your system memory.\n\n# nnUNetPredictor\nThe new nnUNetPredictor class encapsulates the inference code and makes it simple to switch between modes. Your \ncode can hold a nnUNetPredictor instance and perform prediction on the fly. Previously this was not possible and each \nnew prediction request resulted in reloading the parameters and reinstantiating the network architecture. Not ideal.\n\nThe nnUNetPredictor must be initialized manually! 
You will want to use the \n`predictor.initialize_from_trained_model_folder` function for 99% of use cases!\n\nNew feature: If you do not specify an output folder / output files then the predicted segmentations will be \nreturned \n\n\n## Recommended nnU-Net default: predict from source files\n\ntldr:\n- loads images on the fly\n- performs preprocessing in background workers\n- main process focuses only on making predictions\n- results are again given to background workers for resampling and (optional) export\n\npros:\n- best suited for predicting a large number of images\n- nicer to your RAM\n\ncons:\n- not ideal when single images are to be predicted \n- requires images to be present as files\n\nExample:\n```python\n    from nnunetv2.paths import nnUNet_results, nnUNet_raw\n    import torch\n    from batchgenerators.utilities.file_and_folder_operations import join\n    from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor\n    \n    # instantiate the nnUNetPredictor\n    predictor = nnUNetPredictor(\n        tile_step_size=0.5,\n        use_gaussian=True,\n        use_mirroring=True,\n        perform_everything_on_device=True,\n        device=torch.device('cuda', 0),\n        verbose=False,\n        verbose_preprocessing=False,\n        allow_tqdm=True\n    )\n    # initializes the network architecture, loads the checkpoint\n    predictor.initialize_from_trained_model_folder(\n        join(nnUNet_results, 'Dataset003_Liver/nnUNetTrainer__nnUNetPlans__3d_lowres'),\n        use_folds=(0,),\n        checkpoint_name='checkpoint_final.pth',\n    )\n    # variant 1: give input and output folders\n    predictor.predict_from_files(join(nnUNet_raw, 'Dataset003_Liver/imagesTs'),\n                                 join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres'),\n                                 save_probabilities=False, overwrite=False,\n                                 num_processes_preprocessing=2, num_processes_segmentation_export=2,\n              
                   folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)\n```\n\nInstead of giving input and output folders you can also give concrete files. If you give concrete files, there is no \nneed for the _0000 suffix anymore! This can be useful in situations where you have no control over the filenames!\nRemember that the files must be given as 'list of lists' where each entry in the outer list is a case to be predicted \nand the inner list contains all the files belonging to that case. There is just one file for datasets with just one \ninput modality (such as CT) but may be more files for others (such as MRI where there is sometimes T1, T2, Flair etc). \nIMPORTANT: the order in which the files for each case are given must match the order of the channels as defined in the \ndataset.json!\n\nIf you give files as input, you need to give individual output files as output!\n\n```python\n    # variant 2, use list of files as inputs. Note how we use nested lists!!!\n    indir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs')\n    outdir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres')\n    predictor.predict_from_files([[join(indir, 'liver_152_0000.nii.gz')], \n                                  [join(indir, 'liver_142_0000.nii.gz')]],\n                                 [join(outdir, 'liver_152'),\n                                  join(outdir, 'liver_142')],\n                                 save_probabilities=False, overwrite=False,\n                                 num_processes_preprocessing=2, num_processes_segmentation_export=2,\n                                 folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)\n```\n\nDid you know? 
If you do not specify output files, the predicted segmentations will be returned:\n```python\n    # variant 2.5, returns segmentations\n    indir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs')\n    predicted_segmentations = predictor.predict_from_files([[join(indir, 'liver_152_0000.nii.gz')],\n                                  [join(indir, 'liver_142_0000.nii.gz')]],\n                                 None,\n                                 save_probabilities=False, overwrite=True,\n                                 num_processes_preprocessing=2, num_processes_segmentation_export=2,\n                                 folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)\n```\n\n## Prediction from npy arrays\ntldr:\n- you give images as a list of npy arrays\n- performs preprocessing in background workers\n- main process focuses only on making predictions\n- results are again given to background workers for resampling and (optional) export\n\npros:\n- the correct variant for when you have images in RAM already\n- well suited for predicting multiple images\n\ncons:\n- uses more ram than the default\n- unsuited for large number of images as all images must be held in RAM\n\n```python\n    from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO\n\n    img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_147_0000.nii.gz')])\n    img2, props2 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_146_0000.nii.gz')])\n    img3, props3 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_145_0000.nii.gz')])\n    img4, props4 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_144_0000.nii.gz')])\n    # we do not set output files so that the segmentations will be returned. 
You can of course also specify output\n    # files instead (no return value in that case)\n    ret = predictor.predict_from_list_of_npy_arrays([img, img2, img3, img4],\n                                                    None,\n                                                    [props, props2, props3, props4],\n                                                    None, 2, save_probabilities=False,\n                                                    num_processes_segmentation_export=2)\n```\n\n## Predicting a single npy array\n\ntldr:\n- you give one image as npy array\n- axes ordering must match the corresponding training data. The easiest way to achieve that is to use the same I/O class\n                     for loading images as was used during nnU-Net preprocessing! You can find that class in your\n                     plans.json file under the key \"image_reader_writer\". If you decide to freestyle, know that the\n                     default axis ordering for medical images is the one from SimpleITK. 
If you load with nibabel,\n                     you need to transpose your axes AND your spacing from [x,y,z] to [z,y,x]!\n- everything is done in the main process: preprocessing, prediction, resampling, (export)\n- no interlacing, slowest variant!\n- ONLY USE THIS IF YOU CANNOT GIVE NNUNET MULTIPLE IMAGES AT ONCE FOR SOME REASON\n\npros:\n- no messing with multiprocessing\n- no messing with data iterator blabla\n\ncons:\n- slow as heck, yo\n- never the right choice unless you can only give a single image at a time to nnU-Net\n\n```python\n    # predict a single numpy array (SimpleITKIO)\n    img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTr/liver_63_0000.nii.gz')])\n    ret = predictor.predict_single_npy_array(img, props, None, None, False)\n\n    # predict a single numpy array (NibabelIO)\n    from nnunetv2.imageio.nibabel_reader_writer import NibabelIO\n    img, props = NibabelIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTr/liver_63_0000.nii.gz')])\n    ret = predictor.predict_single_npy_array(img, props, None, None, False)\n\n    # The following IS NOT RECOMMENDED. 
Use nnunetv2.imageio!\n    # nibabel, we need to transpose axes and spacing to match the training axes ordering for the nnU-Net default:\n    import nibabel as nib\n    import numpy as np\n    img_nii = nib.load('Dataset003_Liver/imagesTr/liver_63_0000.nii.gz')\n    img = np.asanyarray(img_nii.dataobj).transpose([2, 1, 0])  # reverse axis order to match SITK\n    props = {'spacing': img_nii.header.get_zooms()[::-1]}      # reverse axis order to match SITK\n    ret = predictor.predict_single_npy_array(img, props, None, None, False)\n```\n\n## Predicting with a custom data iterator\ntldr: \n- highly flexible\n- not for newbies\n\npros:\n- you can do everything yourself\n- you have all the freedom you want\n- really fast if you remember to use multiprocessing in your iterator\n\ncons:\n- you need to do everything yourself\n- harder than you might think\n\n```python\n    img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_147_0000.nii.gz')])\n    img2, props2 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_146_0000.nii.gz')])\n    img3, props3 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_145_0000.nii.gz')])\n    img4, props4 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_144_0000.nii.gz')])\n    # each element returned by data_iterator must be a dict with 'data', 'ofile' and 'data_properties' keys!\n    # If 'ofile' is None, the result will be returned instead of written to a file\n    # the iterator is responsible for performing the correct preprocessing!\n    # note how the iterator here does not use multiprocessing -> preprocessing will be done in the main thread!\n    # take a look at the default iterators for predict_from_files and predict_from_list_of_npy_arrays\n    # (they both use predictor.predict_from_data_iterator) for inspiration!\n    def my_iterator(list_of_input_arrs, list_of_input_props):\n        preprocessor = 
predictor.configuration_manager.preprocessor_class(verbose=predictor.verbose)\n        for a, p in zip(list_of_input_arrs, list_of_input_props):\n            data, seg = preprocessor.run_case_npy(a,\n                                                  None,\n                                                  p,\n                                                  predictor.plans_manager,\n                                                  predictor.configuration_manager,\n                                                  predictor.dataset_json)\n            yield {'data': torch.from_numpy(data).contiguous().pin_memory(), 'data_properties': p, 'ofile': None}\n    ret = predictor.predict_from_data_iterator(my_iterator([img, img2, img3, img4], [props, props2, props3, props4]),\n                                               save_probabilities=False, num_processes_segmentation_export=3)\n```\n"
  },
  {
    "path": "nnunetv2/inference/sliding_window_prediction.py",
    "content": "from functools import lru_cache\n\nimport numpy as np\nimport torch\nfrom typing import Union, Tuple, List\nfrom acvl_utils.cropping_and_padding.padding import pad_nd_image\nfrom scipy.ndimage import gaussian_filter\n\n\n@lru_cache(maxsize=2)\ndef compute_gaussian(tile_size: Union[Tuple[int, ...], List[int]], sigma_scale: float = 1. / 8,\n                     value_scaling_factor: float = 1, dtype=torch.float16, device=torch.device('cuda', 0)) \\\n        -> torch.Tensor:\n    tmp = np.zeros(tile_size)\n    center_coords = [i // 2 for i in tile_size]\n    sigmas = [i * sigma_scale for i in tile_size]\n    tmp[tuple(center_coords)] = 1\n    gaussian_importance_map = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0)\n\n    gaussian_importance_map = torch.from_numpy(gaussian_importance_map)\n\n    gaussian_importance_map /= (torch.max(gaussian_importance_map) / value_scaling_factor)\n    gaussian_importance_map = gaussian_importance_map.to(device=device, dtype=dtype)\n    # gaussian_importance_map cannot be 0, otherwise we may end up with nans!\n    mask = gaussian_importance_map == 0\n    gaussian_importance_map[mask] = torch.min(gaussian_importance_map[~mask])\n    return gaussian_importance_map\n\n\ndef compute_steps_for_sliding_window(image_size: Tuple[int, ...], tile_size: Tuple[int, ...], tile_step_size: float) -> \\\n        List[List[int]]:\n    # all() is required here: asserting a non-empty list is always truthy and would never fail\n    assert all(i >= j for i, j in zip(image_size, tile_size)), \"image size must be as large or larger than patch_size\"\n    assert 0 < tile_step_size <= 1, 'step_size must be larger than 0 and smaller or equal to 1'\n\n    # our step width is patch_size*step_size at most, but can be narrower. 
For example if we have image size of\n    # 110, patch size of 64 and step_size of 0.5, then we want to make 3 steps starting at coordinate 0, 23, 46\n    target_step_sizes_in_voxels = [i * tile_step_size for i in tile_size]\n\n    num_steps = [int(np.ceil((i - k) / j)) + 1 for i, j, k in zip(image_size, target_step_sizes_in_voxels, tile_size)]\n\n    steps = []\n    for dim in range(len(tile_size)):\n        # the highest step value for this dimension is\n        max_step_value = image_size[dim] - tile_size[dim]\n        if num_steps[dim] > 1:\n            actual_step_size = max_step_value / (num_steps[dim] - 1)\n        else:\n            actual_step_size = 99999999999  # does not matter because there is only one step at 0\n\n        steps_here = [int(np.round(actual_step_size * i)) for i in range(num_steps[dim])]\n\n        steps.append(steps_here)\n\n    return steps\n\n\nif __name__ == '__main__':\n    a = torch.rand((4, 2, 32, 23))\n    a_npy = a.numpy()\n\n    a_padded = pad_nd_image(a, new_shape=(48, 27))\n    a_npy_padded = pad_nd_image(a_npy, new_shape=(48, 27))\n    assert all([i == j for i, j in zip(a_padded.shape, (4, 2, 48, 27))])\n    assert all([i == j for i, j in zip(a_npy_padded.shape, (4, 2, 48, 27))])\n    assert np.all(a_padded.numpy() == a_npy_padded)\n"
  },
  {
    "path": "nnunetv2/model_sharing/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/model_sharing/entry_points.py",
    "content": "from nnunetv2.model_sharing.model_download import download_and_install_from_url\nfrom nnunetv2.model_sharing.model_export import export_pretrained_model\nfrom nnunetv2.model_sharing.model_import import install_model_from_zip_file\n\n\ndef print_license_warning():\n    print('')\n    print('######################################################')\n    print('!!!!!!!!!!!!!!!!!!!!!!!!WARNING!!!!!!!!!!!!!!!!!!!!!!!')\n    print('######################################################')\n    print(\"Using the pretrained model weights is subject to the license of the dataset they were trained on. Some \"\n          \"allow commercial use, others don't. It is your responsibility to make sure you use them appropriately! Use \"\n          \"nnUNet_print_pretrained_model_info(task_name) to see a summary of the dataset and where to find its license!\")\n    print('######################################################')\n    print('')\n\n\ndef download_by_url():\n    import argparse\n    parser = argparse.ArgumentParser(\n        description=\"Use this to download pretrained models. This script is intended to download models via url only. 
\"\n                    \"CAREFUL: This script will overwrite \"\n                    \"existing models (if they share the same trainer class and plans as \"\n                    \"the pretrained model).\")\n    parser.add_argument(\"url\", type=str, help='URL of the pretrained model')\n    args = parser.parse_args()\n    url = args.url\n    download_and_install_from_url(url)\n\n\ndef install_from_zip_entry_point():\n    import argparse\n    parser = argparse.ArgumentParser(\n        description=\"Use this to install a zip file containing a pretrained model.\")\n    parser.add_argument(\"zip\", type=str, help='zip file')\n    args = parser.parse_args()\n    zip = args.zip\n    install_model_from_zip_file(zip)\n\n\ndef export_pretrained_model_entry():\n    import argparse\n    parser = argparse.ArgumentParser(\n        description=\"Use this to export a trained model as a zip file.\")\n    parser.add_argument('-d', type=str, required=True, help='Dataset name or id')\n    parser.add_argument('-o', type=str, required=True, help='Output file name')\n    parser.add_argument('-c', nargs='+', type=str, required=False,\n                        default=('3d_lowres', '3d_fullres', '2d', '3d_cascade_fullres'),\n                        help=\"List of configuration names\")\n    parser.add_argument('-tr', required=False, type=str, default='nnUNetTrainer', help='Trainer class')\n    parser.add_argument('-p', required=False, type=str, default='nnUNetPlans', help='plans identifier')\n    parser.add_argument('-f', required=False, nargs='+', type=str, default=(0, 1, 2, 3, 4), help='list of fold ids')\n    parser.add_argument('-chk', required=False, nargs='+', type=str, default=('checkpoint_final.pth', ),\n                        help='List of checkpoint names to export. 
Default: checkpoint_final.pth')\n    parser.add_argument('--not_strict', action='store_true', default=False, required=False, help='Set this to allow missing folds and/or configurations')\n    parser.add_argument('--exp_cv_preds', action='store_true', required=False, help='Set this to export the cross-validation predictions as well')\n    args = parser.parse_args()\n\n    export_pretrained_model(dataset_name_or_id=args.d, output_file=args.o, configurations=args.c, trainer=args.tr,\n                            plans_identifier=args.p, folds=args.f, strict=not args.not_strict, save_checkpoints=args.chk,\n                            export_crossval_predictions=args.exp_cv_preds)\n"
  },
  {
    "path": "nnunetv2/model_sharing/model_download.py",
    "content": "from typing import Optional\n\nimport requests\nfrom batchgenerators.utilities.file_and_folder_operations import *\nfrom time import time\nfrom nnunetv2.model_sharing.model_import import install_model_from_zip_file\nfrom nnunetv2.paths import nnUNet_results\nfrom tqdm import tqdm\n\n\ndef download_and_install_from_url(url):\n    assert nnUNet_results is not None, \"Cannot install model because network_training_output_dir is not \" \\\n                                                    \"set (RESULTS_FOLDER missing as environment variable, see \" \\\n                                                    \"Installation instructions)\"\n    print('Downloading pretrained model from url:', url)\n    import http.client\n    http.client.HTTPConnection._http_vsn = 10\n    http.client.HTTPConnection._http_vsn_str = 'HTTP/1.0'\n\n    import os\n    home = os.path.expanduser('~')\n    random_number = int(time() * 1e7)\n    tempfile = join(home, f'.nnunetdownload_{str(random_number)}')\n\n    try:\n        download_file(url=url, local_filename=tempfile, chunk_size=8192 * 16)\n        print(\"Download finished. Extracting...\")\n        install_model_from_zip_file(tempfile)\n        print(\"Done\")\n    except Exception as e:\n        raise e\n    finally:\n        if isfile(tempfile):\n            os.remove(tempfile)\n\n\ndef download_file(url: str, local_filename: str, chunk_size: Optional[int] = 8192 * 16) -> str:\n    # borrowed from https://stackoverflow.com/questions/16694907/download-large-file-in-python-with-requests\n    # NOTE the stream=True parameter below\n    with requests.get(url, stream=True, timeout=100) as r:\n        r.raise_for_status()\n        with tqdm.wrapattr(open(local_filename, 'wb'), \"write\", total=int(r.headers.get(\"Content-Length\"))) as f:\n            for chunk in r.iter_content(chunk_size=chunk_size):\n                f.write(chunk)\n    return local_filename\n\n\n"
  },
  {
    "path": "nnunetv2/model_sharing/model_export.py",
    "content": "import zipfile\n\nfrom nnunetv2.utilities.file_path_utilities import *\n\n\ndef export_pretrained_model(dataset_name_or_id: Union[int, str], output_file: str,\n                            configurations: Tuple[str] = (\"2d\", \"3d_lowres\", \"3d_fullres\", \"3d_cascade_fullres\"),\n                            trainer: str = 'nnUNetTrainer',\n                            plans_identifier: str = 'nnUNetPlans',\n                            folds: Tuple[int, ...] = (0, 1, 2, 3, 4),\n                            strict: bool = True,\n                            save_checkpoints: Tuple[str, ...] = ('checkpoint_final.pth',),\n                            export_crossval_predictions: bool = False) -> None:\n    dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)\n    with(zipfile.ZipFile(output_file, 'w', zipfile.ZIP_DEFLATED)) as zipf:\n        for c in configurations:\n            print(f\"Configuration {c}\")\n            trainer_output_dir = get_output_folder(dataset_name, trainer, plans_identifier, c)\n\n            if not isdir(trainer_output_dir):\n                if strict:\n                    raise RuntimeError(f\"{dataset_name} is missing the trained model of configuration {c}\")\n                else:\n                    continue\n\n            expected_fold_folder = [f\"fold_{i}\" if i != 'all' else 'fold_all' for i in folds]\n            assert all([isdir(join(trainer_output_dir, i)) for i in expected_fold_folder]), \\\n                f\"not all requested folds are present; {dataset_name} {c}; requested folds: {folds}\"\n\n            assert isfile(join(trainer_output_dir, \"plans.json\")), f\"plans.json missing, {dataset_name} {c}\"\n\n            for fold_folder in expected_fold_folder:\n                print(f\"Exporting {fold_folder}\")\n                # debug.json, does not exist yet\n                source_file = join(trainer_output_dir, fold_folder, \"debug.json\")\n                if isfile(source_file):\n                  
  zipf.write(source_file, os.path.relpath(source_file, nnUNet_results))\n\n                # all requested checkpoints\n                for chk in save_checkpoints:\n                    source_file = join(trainer_output_dir, fold_folder, chk)\n                    zipf.write(source_file, os.path.relpath(source_file, nnUNet_results))\n\n                # progress.png\n                source_file = join(trainer_output_dir, fold_folder, \"progress.png\")\n                zipf.write(source_file, os.path.relpath(source_file, nnUNet_results))\n\n                # network_architecture.pdf, if it exists\n                source_file = join(trainer_output_dir, fold_folder, \"network_architecture.pdf\")\n                if isfile(source_file):\n                    zipf.write(source_file, os.path.relpath(source_file, nnUNet_results))\n\n                # validation folder with all predicted segmentations etc\n                if export_crossval_predictions:\n                    source_folder = join(trainer_output_dir, fold_folder, \"validation\")\n                    files = [i for i in subfiles(source_folder, join=False) if not i.endswith('.npz') and not i.endswith('.pkl')]\n                    for f in files:\n                        zipf.write(join(source_folder, f), os.path.relpath(join(source_folder, f), nnUNet_results))\n                # just the summary.json file from the validation\n                else:\n                    source_file = join(trainer_output_dir, fold_folder, \"validation\", \"summary.json\")\n                    zipf.write(source_file, os.path.relpath(source_file, nnUNet_results))\n\n            source_folder = join(trainer_output_dir, f'crossval_results_folds_{folds_tuple_to_string(folds)}')\n            if isdir(source_folder):\n                if export_crossval_predictions:\n                    source_files = subfiles(source_folder, join=True)\n                else:\n                    source_files = [\n                        
join(trainer_output_dir, f'crossval_results_folds_{folds_tuple_to_string(folds)}', i) for i in\n                        ['summary.json', 'postprocessing.pkl', 'postprocessing.json']\n                    ]\n                for s in source_files:\n                    if isfile(s):\n                        zipf.write(s, os.path.relpath(s, nnUNet_results))\n            # plans\n            source_file = join(trainer_output_dir, \"plans.json\")\n            zipf.write(source_file, os.path.relpath(source_file, nnUNet_results))\n            # fingerprint\n            source_file = join(trainer_output_dir, \"dataset_fingerprint.json\")\n            zipf.write(source_file, os.path.relpath(source_file, nnUNet_results))\n            # dataset\n            source_file = join(trainer_output_dir, \"dataset.json\")\n            zipf.write(source_file, os.path.relpath(source_file, nnUNet_results))\n\n        ensemble_dir = join(nnUNet_results, dataset_name, 'ensembles')\n\n        if not isdir(ensemble_dir):\n            print(\"No ensemble directory found for task\", dataset_name_or_id)\n            return\n        subd = subdirs(ensemble_dir, join=False)\n                # figure out whether the models in the ensemble are all within the exported models here\n        for ens in subd:\n            identifiers, folds = convert_ensemble_folder_to_model_identifiers_and_folds(ens)\n            ok = True\n            for i in identifiers:\n                tr, pl, c = convert_identifier_to_trainer_plans_config(i)\n                if tr == trainer and pl == plans_identifier and c in configurations:\n                    pass\n                else:\n                    ok = False\n            if ok:\n                print(f'found matching ensemble: {ens}')\n                source_folder = join(ensemble_dir, ens)\n                if export_crossval_predictions:\n                    source_files = subfiles(source_folder, join=True)\n                else:\n                    source_files = 
[\n                        join(source_folder, i) for i in\n                        ['summary.json', 'postprocessing.pkl', 'postprocessing.json'] if isfile(join(source_folder, i))\n                    ]\n                for s in source_files:\n                    zipf.write(s, os.path.relpath(s, nnUNet_results))\n        inference_information_file = join(nnUNet_results, dataset_name, 'inference_information.json')\n        if isfile(inference_information_file):\n            zipf.write(inference_information_file, os.path.relpath(inference_information_file, nnUNet_results))\n        inference_information_txt_file = join(nnUNet_results, dataset_name, 'inference_information.txt')\n        if isfile(inference_information_txt_file):\n            zipf.write(inference_information_txt_file, os.path.relpath(inference_information_txt_file, nnUNet_results))\n    print('Done')\n\n\nif __name__ == '__main__':\n    export_pretrained_model(2, '/home/fabian/temp/dataset2.zip', strict=False, export_crossval_predictions=True, folds=(0, ))\n"
  },
  {
    "path": "nnunetv2/model_sharing/model_import.py",
    "content": "import zipfile\n\nfrom nnunetv2.paths import nnUNet_results\n\n\ndef install_model_from_zip_file(zip_file: str):\n    with zipfile.ZipFile(zip_file, 'r') as zip_ref:\n        zip_ref.extractall(nnUNet_results)"
  },
  {
    "path": "nnunetv2/paths.py",
    "content": "#    Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not use this file except in compliance with the License.\n#    You may obtain a copy of the License at\n#\n#        http://www.apache.org/licenses/LICENSE-2.0\n#\n#    Unless required by applicable law or agreed to in writing, software\n#    distributed under the License is distributed on an \"AS IS\" BASIS,\n#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#    See the License for the specific language governing permissions and\n#    limitations under the License.\n\nimport os\n\n\"\"\"\nPLEASE READ documentation/setting_up_paths.md FOR INFORMATION TO HOW TO SET THIS UP\n\"\"\"\n\nnnUNet_raw = os.environ.get('nnUNet_raw')\nnnUNet_preprocessed = os.environ.get('nnUNet_preprocessed')\nnnUNet_results = os.environ.get('nnUNet_results')\n\nif nnUNet_raw is None:\n    print(\"nnUNet_raw is not defined and nnU-Net can only be used on data for which preprocessed files \"\n          \"are already present on your system. nnU-Net cannot be used for experiment planning and preprocessing like \"\n          \"this. If this is not intended, please read documentation/setting_up_paths.md for information on how to set \"\n          \"this up properly.\")\n\nif nnUNet_preprocessed is None:\n    print(\"nnUNet_preprocessed is not defined and nnU-Net can not be used for preprocessing \"\n          \"or training. If this is not intended, please read documentation/setting_up_paths.md for information on how \"\n          \"to set this up.\")\n\nif nnUNet_results is None:\n    print(\"nnUNet_results is not defined and nnU-Net cannot be used for training or \"\n          \"inference. If this is not intended behavior, please read documentation/setting_up_paths.md for information \"\n          \"on how to set this up.\")\n"
  },
  {
    "path": "nnunetv2/postprocessing/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/postprocessing/remove_connected_components.py",
    "content": "import argparse\nimport multiprocessing\nimport shutil\nfrom typing import Union, Tuple, List, Callable\n\nimport numpy as np\nfrom acvl_utils.morphology.morphology_helper import remove_all_but_largest_component\nfrom batchgenerators.utilities.file_and_folder_operations import load_json, subfiles, maybe_mkdir_p, join, isfile, \\\n    isdir, save_pickle, load_pickle, save_json\nfrom nnunetv2.configuration import default_num_processes\nfrom nnunetv2.evaluation.accumulate_cv_results import accumulate_cv_results\nfrom nnunetv2.evaluation.evaluate_predictions import region_or_label_to_mask, compute_metrics_on_folder, \\\n    load_summary_json, label_or_region_to_key\nfrom nnunetv2.imageio.base_reader_writer import BaseReaderWriter\nfrom nnunetv2.paths import nnUNet_raw\nfrom nnunetv2.utilities.file_path_utilities import folds_tuple_to_string\nfrom nnunetv2.utilities.json_export import recursive_fix_for_json_export\nfrom nnunetv2.utilities.plans_handling.plans_handler import PlansManager\n\n\ndef remove_all_but_largest_component_from_segmentation(segmentation: np.ndarray,\n                                                       labels_or_regions: Union[int, Tuple[int, ...],\n                                                                                List[Union[int, Tuple[int, ...]]]],\n                                                       background_label: int = 0) -> np.ndarray:\n    mask = np.zeros_like(segmentation, dtype=bool)\n    if not isinstance(labels_or_regions, list):\n        labels_or_regions = [labels_or_regions]\n    for l_or_r in labels_or_regions:\n        mask |= region_or_label_to_mask(segmentation, l_or_r)\n    mask_keep = remove_all_but_largest_component(mask)\n    ret = np.copy(segmentation)  # do not modify the input!\n    ret[mask & ~mask_keep] = background_label\n    return ret\n\n\ndef apply_postprocessing(segmentation: np.ndarray, pp_fns: List[Callable], pp_fn_kwargs: List[dict]):\n    for fn, kwargs in zip(pp_fns, 
pp_fn_kwargs):\n        segmentation = fn(segmentation, **kwargs)\n    return segmentation\n\n\ndef load_postprocess_save(segmentation_file: str,\n                          output_fname: str,\n                          image_reader_writer: BaseReaderWriter,\n                          pp_fns: List[Callable],\n                          pp_fn_kwargs: List[dict]):\n    seg, props = image_reader_writer.read_seg(segmentation_file)\n    seg = apply_postprocessing(seg[0], pp_fns, pp_fn_kwargs)\n    image_reader_writer.write_seg(seg, output_fname, props)\n\n\ndef determine_postprocessing(folder_predictions: str,\n                             folder_ref: str,\n                             plans_file_or_dict: Union[str, dict],\n                             dataset_json_file_or_dict: Union[str, dict],\n                             num_processes: int = default_num_processes,\n                             keep_postprocessed_files: bool = True):\n    \"\"\"\n    Determines nnUNet postprocessing. Its output is a postprocessing.pkl file in folder_predictions which can be\n    used with apply_postprocessing_to_folder.\n\n    Postprocessed files are saved in folder_predictions/postprocessed. Set\n    keep_postprocessed_files=False to delete these files after this function is done (temp files will be created\n    and deleted regardless).\n\n    If plans_file_or_dict or dataset_json_file_or_dict are None, we will look for them in folder_predictions\n    \"\"\"\n    output_folder = join(folder_predictions, 'postprocessed')\n\n    if plans_file_or_dict is None:\n        expected_plans_file = join(folder_predictions, 'plans.json')\n        if not isfile(expected_plans_file):\n            raise RuntimeError(f\"Expected plans file missing: {expected_plans_file}. The plans file should have been \"\n                               f\"created while running nnUNetv2_predict. 
Sadge.\")\n        plans_file_or_dict = load_json(expected_plans_file)\n    plans_manager = PlansManager(plans_file_or_dict)\n\n    if dataset_json_file_or_dict is None:\n        expected_dataset_json_file = join(folder_predictions, 'dataset.json')\n        if not isfile(expected_dataset_json_file):\n            raise RuntimeError(\n                f\"Expected dataset.json file missing: {expected_dataset_json_file}. The dataset.json should have been \"\n                f\"copied while running nnUNetv2_predict. Sadge.\")\n        dataset_json_file_or_dict = load_json(expected_dataset_json_file)\n\n    if not isinstance(dataset_json_file_or_dict, dict):\n        dataset_json = load_json(dataset_json_file_or_dict)\n    else:\n        dataset_json = dataset_json_file_or_dict\n\n    rw = plans_manager.image_reader_writer_class()\n    label_manager = plans_manager.get_label_manager(dataset_json)\n    labels_or_regions = label_manager.foreground_regions if label_manager.has_regions else label_manager.foreground_labels\n\n    predicted_files = subfiles(folder_predictions, suffix=dataset_json['file_ending'], join=False)\n    ref_files = subfiles(folder_ref, suffix=dataset_json['file_ending'], join=False)\n    # we should print a warning if not all files from folder_ref are present in folder_predictions\n    if not all([i in predicted_files for i in ref_files]):\n        print(f'WARNING: Not all files in folder_ref were found in folder_predictions. 
Determining postprocessing '\n              f'should always be done on the entire dataset!')\n\n    # before we start we should evaluate the images in the source folder\n    if not isfile(join(folder_predictions, 'summary.json')):\n        compute_metrics_on_folder(folder_ref,\n                                  folder_predictions,\n                                  join(folder_predictions, 'summary.json'),\n                                  rw,\n                                  dataset_json['file_ending'],\n                                  labels_or_regions,\n                                  label_manager.ignore_label,\n                                  num_processes)\n\n    # we save the postprocessing functions in here\n    pp_fns = []\n    pp_fn_kwargs = []\n\n    # pool party!\n    with multiprocessing.get_context(\"spawn\").Pool(num_processes) as pool:\n        # now let's see whether removing all but the largest foreground region improves the scores\n        output_here = join(output_folder, 'temp', 'keep_largest_fg')\n        maybe_mkdir_p(output_here)\n        pp_fn = remove_all_but_largest_component_from_segmentation\n        kwargs = {\n            'labels_or_regions': label_manager.foreground_labels,\n        }\n\n        pool.starmap(\n            load_postprocess_save,\n            zip(\n                [join(folder_predictions, i) for i in predicted_files],\n                [join(output_here, i) for i in predicted_files],\n                [rw] * len(predicted_files),\n                [[pp_fn]] * len(predicted_files),\n                [[kwargs]] * len(predicted_files)\n            )\n        )\n        compute_metrics_on_folder(folder_ref,\n                                  output_here,\n                                  join(output_here, 'summary.json'),\n                                  rw,\n                                  dataset_json['file_ending'],\n                                  labels_or_regions,\n                                  
label_manager.ignore_label,\n                                  num_processes)\n        # now we need to figure out if doing this improved the dice scores. We will implement that defensively in so far\n        # that if a single class got worse as a result we won't do this. We can change this in the future but right now I\n        # prefer to do it this way\n        baseline_results = load_summary_json(join(folder_predictions, 'summary.json'))\n        pp_results = load_summary_json(join(output_here, 'summary.json'))\n        do_this = pp_results['foreground_mean']['Dice'] > baseline_results['foreground_mean']['Dice']\n        if do_this:\n            for class_id in pp_results['mean'].keys():\n                if pp_results['mean'][class_id]['Dice'] < baseline_results['mean'][class_id]['Dice']:\n                    do_this = False\n                    break\n        if do_this:\n            print(f'Results were improved by removing all but the largest foreground region. '\n                  f'Mean dice before: {round(baseline_results[\"foreground_mean\"][\"Dice\"], 5)} '\n                  f'after: {round(pp_results[\"foreground_mean\"][\"Dice\"], 5)}')\n            source = output_here\n            pp_fns.append(pp_fn)\n            pp_fn_kwargs.append(kwargs)\n        else:\n            print(f'Removing all but the largest foreground region did not improve results!')\n            source = folder_predictions\n\n        # in the old nnU-Net we could just apply all-but-largest component removal to all classes at the same time and\n        # then evaluate for each class whether this improved results. 
This is no longer possible because we now support\n        # region-based predictions and regions can overlap, causing interactions\n        # in principle the order with which the postprocessing is applied to the regions matter as well and should be\n        # investigated, but due to some things that I am too lazy to explain right now it's going to be alright (I think)\n        # to stick to the order in which they are declared in dataset.json (if you want to think about it then think about\n        # region_class_order)\n        # 2023_02_06: I hate myself for the comment above. Thanks past me\n        if len(labels_or_regions) > 1:\n            for label_or_region in labels_or_regions:\n                pp_fn = remove_all_but_largest_component_from_segmentation\n                kwargs = {\n                    'labels_or_regions': label_or_region,\n                }\n\n                output_here = join(output_folder, 'temp', 'keep_largest_perClassOrRegion')\n                maybe_mkdir_p(output_here)\n\n                pool.starmap(\n                    load_postprocess_save,\n                    zip(\n                        [join(source, i) for i in predicted_files],\n                        [join(output_here, i) for i in predicted_files],\n                        [rw] * len(predicted_files),\n                        [[pp_fn]] * len(predicted_files),\n                        [[kwargs]] * len(predicted_files)\n                    )\n                )\n                compute_metrics_on_folder(folder_ref,\n                                          output_here,\n                                          join(output_here, 'summary.json'),\n                                          rw,\n                                          dataset_json['file_ending'],\n                                          labels_or_regions,\n                                          label_manager.ignore_label,\n                                          num_processes)\n                
baseline_results = load_summary_json(join(source, 'summary.json'))\n                pp_results = load_summary_json(join(output_here, 'summary.json'))\n                do_this = pp_results['mean'][label_or_region]['Dice'] > baseline_results['mean'][label_or_region]['Dice']\n                if do_this:\n                    print(f'Results were improved by removing all but the largest component for {label_or_region}. '\n                          f'Dice before: {round(baseline_results[\"mean\"][label_or_region][\"Dice\"], 5)} '\n                          f'after: {round(pp_results[\"mean\"][label_or_region][\"Dice\"], 5)}')\n                    if isdir(join(output_folder, 'temp', 'keep_largest_perClassOrRegion_currentBest')):\n                        shutil.rmtree(join(output_folder, 'temp', 'keep_largest_perClassOrRegion_currentBest'))\n                    shutil.move(output_here, join(output_folder, 'temp', 'keep_largest_perClassOrRegion_currentBest'), )\n                    source = join(output_folder, 'temp', 'keep_largest_perClassOrRegion_currentBest')\n                    pp_fns.append(pp_fn)\n                    pp_fn_kwargs.append(kwargs)\n                else:\n                    print(f'Removing all but the largest component for {label_or_region} did not improve results! 
'\n                          f'Dice before: {round(baseline_results[\"mean\"][label_or_region][\"Dice\"], 5)} '\n                          f'after: {round(pp_results[\"mean\"][label_or_region][\"Dice\"], 5)}')\n    [shutil.copy(join(source, i), join(output_folder, i)) for i in subfiles(source, join=False)]\n    save_pickle((pp_fns, pp_fn_kwargs), join(folder_predictions, 'postprocessing.pkl'))\n\n    baseline_results = load_summary_json(join(folder_predictions, 'summary.json'))\n    final_results = load_summary_json(join(output_folder, 'summary.json'))\n    tmp = {\n        'input_folder': {i: baseline_results[i] for i in ['foreground_mean', 'mean']},\n        'postprocessed': {i: final_results[i] for i in ['foreground_mean', 'mean']},\n        'postprocessing_fns': [i.__name__ for i in pp_fns],\n        'postprocessing_kwargs': pp_fn_kwargs,\n    }\n    # json is very annoying. Can't handle tuples as dict keys.\n    tmp['input_folder']['mean'] = {label_or_region_to_key(k): tmp['input_folder']['mean'][k] for k in\n                                   tmp['input_folder']['mean'].keys()}\n    tmp['postprocessed']['mean'] = {label_or_region_to_key(k): tmp['postprocessed']['mean'][k] for k in\n                                    tmp['postprocessed']['mean'].keys()}\n    # did I already say that I hate json? 
\"TypeError: Object of type int64 is not JSON serializable\"\n    recursive_fix_for_json_export(tmp)\n    save_json(tmp, join(folder_predictions, 'postprocessing.json'))\n\n    shutil.rmtree(join(output_folder, 'temp'))\n\n    if not keep_postprocessed_files:\n        shutil.rmtree(output_folder)\n    return pp_fns, pp_fn_kwargs\n\n\ndef apply_postprocessing_to_folder(input_folder: str,\n                                   output_folder: str,\n                                   pp_fns: List[Callable],\n                                   pp_fn_kwargs: List[dict],\n                                   plans_file_or_dict: Union[str, dict] = None,\n                                   dataset_json_file_or_dict: Union[str, dict] = None,\n                                   num_processes=8) -> None:\n    \"\"\"\n    If plans_file_or_dict or dataset_json_file_or_dict are None, we will look for them in input_folder\n    \"\"\"\n    if plans_file_or_dict is None:\n        expected_plans_file = join(input_folder, 'plans.json')\n        if not isfile(expected_plans_file):\n            raise RuntimeError(f\"Expected plans file missing: {expected_plans_file}. The plans file should have been \"\n                               f\"created while running nnUNetv2_predict. Sadge. If the folder you want to apply \"\n                               f\"postprocessing to was create from an ensemble then just specify one of the \"\n                               f\"plans files of the ensemble members in plans_file_or_dict\")\n        plans_file_or_dict = load_json(expected_plans_file)\n    plans_manager = PlansManager(plans_file_or_dict)\n\n    if dataset_json_file_or_dict is None:\n        expected_dataset_json_file = join(input_folder, 'dataset.json')\n        if not isfile(expected_dataset_json_file):\n            raise RuntimeError(\n                f\"Expected plans file missing: {expected_dataset_json_file}. 
The dataset.json should have been \"\n                f\"copied while running nnUNetv2_predict/nnUNetv2_ensemble. Sadge.\")\n        dataset_json_file_or_dict = load_json(expected_dataset_json_file)\n\n    if not isinstance(dataset_json_file_or_dict, dict):\n        dataset_json = load_json(dataset_json_file_or_dict)\n    else:\n        dataset_json = dataset_json_file_or_dict\n\n    rw = plans_manager.image_reader_writer_class()\n\n    maybe_mkdir_p(output_folder)\n    with multiprocessing.get_context(\"spawn\").Pool(num_processes) as p:\n        files = subfiles(input_folder, suffix=dataset_json['file_ending'], join=False)\n\n        _ = p.starmap(load_postprocess_save,\n                      zip(\n                          [join(input_folder, i) for i in files],\n                          [join(output_folder, i) for i in files],\n                          [rw] * len(files),\n                          [pp_fns] * len(files),\n                          [pp_fn_kwargs] * len(files)\n                      )\n                      )\n\n\ndef entry_point_determine_postprocessing_folder():\n    parser = argparse.ArgumentParser('Writes postprocessing.pkl and postprocessing.json in input_folder.')\n    parser.add_argument('-i', type=str, required=True, help='Input folder')\n    parser.add_argument('-ref', type=str, required=True, help='Folder with gt labels')\n    parser.add_argument('-plans_json', type=str, required=False, default=None,\n                        help=\"plans file to use. If not specified we will look for the plans.json file in the \"\n                             \"input folder (input_folder/plans.json)\")\n    parser.add_argument('-dataset_json', type=str, required=False, default=None,\n                        help=\"dataset.json file to use. 
If not specified we will look for the dataset.json file in the \"\n                             \"input folder (input_folder/dataset.json)\")\n    parser.add_argument('-np', type=int, required=False, default=default_num_processes,\n                        help=f\"number of processes to use. Default: {default_num_processes}\")\n    parser.add_argument('--remove_postprocessed', action='store_true', required=False,\n                        help='set this if you don\\'t want to keep the postprocessed files')\n\n    args = parser.parse_args()\n    determine_postprocessing(args.i, args.ref, args.plans_json, args.dataset_json, args.np,\n                             not args.remove_postprocessed)\n\n\ndef entry_point_apply_postprocessing():\n    parser = argparse.ArgumentParser('Applies postprocessing specified in pp_pkl_file to input folder.')\n    parser.add_argument('-i', type=str, required=True, help='Input folder')\n    parser.add_argument('-o', type=str, required=True, help='Output folder')\n    parser.add_argument('-pp_pkl_file', type=str, required=True, help='postprocessing.pkl file')\n    parser.add_argument('-np', type=int, required=False, default=default_num_processes,\n                        help=f\"number of processes to use. Default: {default_num_processes}\")\n    parser.add_argument('-plans_json', type=str, required=False, default=None,\n                        help=\"plans file to use. If not specified we will look for the plans.json file in the \"\n                             \"input folder (input_folder/plans.json)\")\n    parser.add_argument('-dataset_json', type=str, required=False, default=None,\n                        help=\"dataset.json file to use. 
If not specified we will look for the dataset.json file in the \"\n                             \"input folder (input_folder/dataset.json)\")\n    args = parser.parse_args()\n    pp_fns, pp_fn_kwargs = load_pickle(args.pp_pkl_file)\n    apply_postprocessing_to_folder(args.i, args.o, pp_fns, pp_fn_kwargs, args.plans_json, args.dataset_json, args.np)\n\n\nif __name__ == '__main__':\n    trained_model_folder = '/home/fabian/results/nnUNet_remake/Dataset004_Hippocampus/nnUNetTrainer__nnUNetPlans__3d_fullres'\n    labelstr = join(nnUNet_raw, 'Dataset004_Hippocampus', 'labelsTr')\n    plans_manager = PlansManager(join(trained_model_folder, 'plans.json'))\n    dataset_json = load_json(join(trained_model_folder, 'dataset.json'))\n    folds = (0, 1, 2, 3, 4)\n    label_manager = plans_manager.get_label_manager(dataset_json)\n\n    merged_output_folder = join(trained_model_folder, f'crossval_results_folds_{folds_tuple_to_string(folds)}')\n    accumulate_cv_results(trained_model_folder, merged_output_folder, folds, 8, False)\n\n    fns, kwargs = determine_postprocessing(merged_output_folder, labelstr, plans_manager.plans,\n                                           dataset_json, 8, keep_postprocessed_files=True)\n    save_pickle((fns, kwargs), join(trained_model_folder, 'postprocessing.pkl'))\n    fns, kwargs = load_pickle(join(trained_model_folder, 'postprocessing.pkl'))\n\n    apply_postprocessing_to_folder(merged_output_folder, merged_output_folder + '_pp', fns, kwargs,\n                                   plans_manager.plans, dataset_json,\n                                   8)\n    compute_metrics_on_folder(labelstr,\n                              merged_output_folder + '_pp',\n                              join(merged_output_folder + '_pp', 'summary.json'),\n                              plans_manager.image_reader_writer_class(),\n                              dataset_json['file_ending'],\n                              label_manager.foreground_regions if 
label_manager.has_regions else label_manager.foreground_labels,\n                              label_manager.ignore_label,\n                              8)\n"
  },
  {
    "path": "nnunetv2/preprocessing/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/preprocessing/cropping/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/preprocessing/cropping/cropping.py",
    "content": "import numpy as np\nfrom scipy.ndimage import binary_fill_holes\nfrom acvl_utils.cropping_and_padding.bounding_boxes import get_bbox_from_mask, bounding_box_to_slice\n\n\ndef create_nonzero_mask(data):\n    \"\"\"\n\n    :param data:\n    :return: the mask is True where the data is nonzero\n    \"\"\"\n    assert data.ndim in (3, 4), \"data must have shape (C, X, Y, Z) or shape (C, X, Y)\"\n    nonzero_mask = data[0] != 0\n    for c in range(1, data.shape[0]):\n        nonzero_mask |= data[c] != 0\n    return binary_fill_holes(nonzero_mask)\n\n\ndef crop_to_nonzero(data, seg=None, nonzero_label=-1):\n    \"\"\"\n\n    :param data:\n    :param seg:\n    :param nonzero_label: this will be written into the segmentation map\n    :return:\n    \"\"\"\n    nonzero_mask = create_nonzero_mask(data)\n    bbox = get_bbox_from_mask(nonzero_mask)\n    slicer = bounding_box_to_slice(bbox)\n    nonzero_mask = nonzero_mask[slicer][None]\n    \n    slicer = (slice(None), ) + slicer\n    data = data[slicer]\n    if seg is not None:\n        seg = seg[slicer]\n        seg[(seg == 0) & (~nonzero_mask)] = nonzero_label\n    else:\n        seg = np.where(nonzero_mask, np.int8(0), np.int8(nonzero_label))\n    return data, seg, bbox\n\n\n"
  },
  {
    "path": "nnunetv2/preprocessing/normalization/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/preprocessing/normalization/default_normalization_schemes.py",
    "content": "from abc import ABC, abstractmethod\nfrom typing import Type\n\nimport numpy as np\nfrom numpy import number\n\n\nclass ImageNormalization(ABC):\n    leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = None\n\n    def __init__(self, use_mask_for_norm: bool = None, intensityproperties: dict = None,\n                 target_dtype: Type[number] = np.float32):\n        assert use_mask_for_norm is None or isinstance(use_mask_for_norm, bool)\n        self.use_mask_for_norm = use_mask_for_norm\n        assert isinstance(intensityproperties, dict)\n        self.intensityproperties = intensityproperties\n        self.target_dtype = target_dtype\n\n    @abstractmethod\n    def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray:\n        \"\"\"\n        Image and seg must have the same shape. Seg is not always used\n        \"\"\"\n        pass\n\n\nclass ZScoreNormalization(ImageNormalization):\n    leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = True\n\n    def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray:\n        \"\"\"\n        here seg is used to store the zero valued region. The value for that region in the segmentation is -1 by\n        default.\n        \"\"\"\n        eps = 1e-8 if not self.target_dtype == np.float16 else 1e-4\n        image = image.astype(self.target_dtype, copy=False)\n        if self.use_mask_for_norm:\n            assert seg is not None, (\"use_mask_for_norm is set, please provide a mask for the nonzero areas of the \"\n                                     \"image via seg. The mask will be computed as `mask = seg >= 0`. You can use \"\n                                     \"create_nonzero_mask from nnunetv2/preprocessing/cropping\")\n        if seg is not None and self.use_mask_for_norm:\n            # negative values in the segmentation encode the 'outside' region (think zero values around the brain as\n            # in BraTS). 
We want to run the normalization only in the brain region, so we need to mask the image.\n            # The default nnU-net sets use_mask_for_norm to True if cropping to the nonzero region substantially\n            # reduced the image size.\n            mask = seg >= 0\n            mean = image[mask].mean()\n            std = image[mask].std()\n            image[mask] = (image[mask] - mean) / (max(std, eps))\n        else:\n            mean = image.mean()\n            std = image.std()\n            image -= mean\n            image /= (max(std, eps))\n        return image\n\n\nclass CTNormalization(ImageNormalization):\n    leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False\n\n    def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray:\n        assert self.intensityproperties is not None, \"CTNormalization requires intensity properties\"\n        eps = 1e-8 if not self.target_dtype == np.float16 else 1e-4\n        mean_intensity = self.intensityproperties['mean']\n        std_intensity = self.intensityproperties['std']\n        lower_bound = self.intensityproperties['percentile_00_5']\n        upper_bound = self.intensityproperties['percentile_99_5']\n\n        image = image.astype(self.target_dtype, copy=False)\n        np.clip(image, lower_bound, upper_bound, out=image)\n        image -= mean_intensity\n        image /= max(std_intensity, eps)\n        return image\n\n\nclass NoNormalization(ImageNormalization):\n    leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False\n\n    def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray:\n        return image.astype(self.target_dtype, copy=False)\n\n\nclass RescaleTo01Normalization(ImageNormalization):\n    leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False\n\n    def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray:\n        eps = 1e-8 if not self.target_dtype == np.float16 else 1e-4\n        image = 
image.astype(self.target_dtype, copy=False)\n        image -= image.min()\n        image /= np.clip(image.max(), a_min=eps, a_max=None)\n        return image\n\n\nclass RGBTo01Normalization(ImageNormalization):\n    leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False\n\n    def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray:\n        assert image.min() >= 0, \"RGB images are uint 8, for whatever reason I found pixel values smaller than 0. \" \\\n                                 \"Your images do not seem to be RGB images\"\n        assert image.max() <= 255, \"RGB images are uint 8, for whatever reason I found pixel values greater than 255\" \\\n                                   \". Your images do not seem to be RGB images\"\n        image = image.astype(self.target_dtype, copy=False)\n        image /= 255.\n        return image\n\n"
  },
  {
    "path": "nnunetv2/preprocessing/normalization/map_channel_name_to_normalization.py",
    "content": "from typing import Type\n\nfrom nnunetv2.preprocessing.normalization.default_normalization_schemes import CTNormalization, NoNormalization, \\\n    ZScoreNormalization, RescaleTo01Normalization, RGBTo01Normalization, ImageNormalization\n\nchannel_name_to_normalization_mapping = {\n    'ct': CTNormalization,\n    'nonorm': NoNormalization,\n    'zscore': ZScoreNormalization,\n    'rescale_to_0_1': RescaleTo01Normalization,\n    'rgb_to_0_1': RGBTo01Normalization\n}\n\n\ndef get_normalization_scheme(channel_name: str) -> Type[ImageNormalization]:\n    \"\"\"\n    If we find the channel_name in channel_name_to_normalization_mapping return the corresponding normalization. If it is\n    not found, use the default (ZScoreNormalization)\n    \"\"\"\n    norm_scheme = channel_name_to_normalization_mapping.get(channel_name.casefold())\n    if norm_scheme is None:\n        norm_scheme = ZScoreNormalization\n    # print('Using %s for image normalization' % norm_scheme.__name__)\n    return norm_scheme\n"
  },
  {
    "path": "nnunetv2/preprocessing/normalization/readme.md",
    "content": "The channel_names entry in dataset.json only determines the normlaization scheme. So if you want to use something different \nthen you can just\n- create a new subclass of ImageNormalization\n- map your custom channel identifier to that subclass in channel_name_to_normalization_mapping\n- run plan and preprocess again with your custom normlaization scheme"
  },
  {
    "path": "nnunetv2/preprocessing/preprocessors/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/preprocessing/preprocessors/default_preprocessor.py",
    "content": "#    Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not use this file except in compliance with the License.\n#    You may obtain a copy of the License at\n#\n#        http://www.apache.org/licenses/LICENSE-2.0\n#\n#    Unless required by applicable law or agreed to in writing, software\n#    distributed under the License is distributed on an \"AS IS\" BASIS,\n#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#    See the License for the specific language governing permissions and\n#    limitations under the License.\nimport multiprocessing\nimport shutil\nfrom time import sleep\nfrom typing import Tuple\nfrom typing import Union\n\nimport SimpleITK\nimport numpy as np\nfrom batchgenerators.utilities.file_and_folder_operations import *\nfrom tqdm import tqdm\n\nimport nnunetv2\nfrom nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw\nfrom nnunetv2.preprocessing.cropping.cropping import crop_to_nonzero\nfrom nnunetv2.preprocessing.resampling.default_resampling import compute_new_shape\nfrom nnunetv2.training.dataloading.nnunet_dataset import nnUNetDatasetBlosc2\nfrom nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name\nfrom nnunetv2.utilities.find_class_by_name import recursive_find_python_class\nfrom nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager\nfrom nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets\n\n\nclass DefaultPreprocessor(object):\n    def __init__(self, verbose: bool = True):\n        self.verbose = verbose\n        self.show_progress_bar = True\n        \"\"\"\n        Everything we need is in the plans. 
Those are given when run() is called\n        \"\"\"\n\n    def run_case_npy(self, data: np.ndarray, seg: Union[np.ndarray, None], properties: dict,\n                     plans_manager: PlansManager, configuration_manager: ConfigurationManager,\n                     dataset_json: Union[dict, str]):\n        # let's not mess up the inputs!\n        data = data.astype(np.float32)  # this creates a copy\n        if seg is not None:\n            assert data.shape[1:] == seg.shape[1:], \"Shape mismatch between image and segmentation. Please fix your dataset and make use of the --verify_dataset_integrity flag to ensure everything is correct\"\n            seg = np.copy(seg)\n\n        has_seg = seg is not None\n\n        # apply transpose_forward, this also needs to be applied to the spacing!\n        data = data.transpose([0, *[i + 1 for i in plans_manager.transpose_forward]])\n        if seg is not None:\n            seg = seg.transpose([0, *[i + 1 for i in plans_manager.transpose_forward]])\n        original_spacing = [properties['spacing'][i] for i in plans_manager.transpose_forward]\n\n        # crop, remember to store size before cropping!\n        shape_before_cropping = data.shape[1:]\n        properties['shape_before_cropping'] = shape_before_cropping\n        # this command will generate a segmentation. 
This is important because of the nonzero mask which we may need\n        data, seg, bbox = crop_to_nonzero(data, seg)\n        properties['bbox_used_for_cropping'] = bbox\n        # print(data.shape, seg.shape)\n        properties['shape_after_cropping_and_before_resampling'] = data.shape[1:]\n\n        # resample\n        target_spacing = configuration_manager.spacing  # this should already be transposed\n\n        if len(target_spacing) < len(data.shape[1:]):\n            # target spacing for 2d has 2 entries but the data and original_spacing have three because everything is 3d\n            # in 2d configuration we do not change the spacing between slices\n            target_spacing = [original_spacing[0]] + target_spacing\n        new_shape = compute_new_shape(data.shape[1:], original_spacing, target_spacing)\n\n        # normalize\n        # normalization MUST happen before resampling or we get huge problems with resampled nonzero masks no\n        # longer fitting the images perfectly!\n        data = self._normalize(data, seg, configuration_manager,\n                               plans_manager.foreground_intensity_properties_per_channel)\n\n        # print('current shape', data.shape[1:], 'current_spacing', original_spacing,\n        #       '\\ntarget shape', new_shape, 'target_spacing', target_spacing)\n        old_shape = data.shape[1:]\n        data = configuration_manager.resampling_fn_data(data, new_shape, original_spacing, target_spacing)\n        seg = configuration_manager.resampling_fn_seg(seg, new_shape, original_spacing, target_spacing)\n        if self.verbose:\n            print(f'old shape: {old_shape}, new_shape: {new_shape}, old_spacing: {original_spacing}, '\n                  f'new_spacing: {target_spacing}, fn_data: {configuration_manager.resampling_fn_data}')\n\n        # if we have a segmentation, sample foreground locations for oversampling and add those to properties\n        if has_seg:\n            # reinstantiating LabelManager for 
each case is not ideal. We could replace the dataset_json argument\n            # with a LabelManager Instance in this function because that's all its used for. Dunno what's better.\n            # LabelManager is pretty light computation-wise.\n            label_manager = plans_manager.get_label_manager(dataset_json)\n            collect_for_this = label_manager.foreground_regions if label_manager.has_regions \\\n                else label_manager.foreground_labels\n\n            # when using the ignore label we want to sample only from annotated regions. Therefore we also need to\n            # collect samples uniformly from all classes (incl background)\n            if label_manager.has_ignore_label:\n                collect_for_this.append([-1] + label_manager.all_labels)\n\n            # no need to filter background in regions because it is already filtered in handle_labels\n            # print(all_labels, regions)\n            properties['class_locations'] = self._sample_foreground_locations(seg, collect_for_this,\n                                                                                   verbose=self.verbose)\n            seg = self.modify_seg_fn(seg, plans_manager, dataset_json, configuration_manager)\n        if np.max(seg) > 127:\n            seg = seg.astype(np.int16)\n        else:\n            seg = seg.astype(np.int8)\n        return data, seg, properties\n\n    def run_case(self, image_files: List[str], seg_file: Union[str, None], plans_manager: PlansManager,\n                 configuration_manager: ConfigurationManager,\n                 dataset_json: Union[dict, str]):\n        \"\"\"\n        seg file can be none (test cases)\n\n        order of operations is: transpose -> crop -> resample\n        so when we export we need to run the following order: resample -> crop -> transpose (we could also run\n        transpose at a different place, but reverting the order of operations done during preprocessing seems cleaner)\n        \"\"\"\n       
 if isinstance(dataset_json, str):\n            dataset_json = load_json(dataset_json)\n\n        rw = plans_manager.image_reader_writer_class()\n\n        # load image(s)\n        data, data_properties = rw.read_images(image_files)\n\n        # if possible, load seg\n        if seg_file is not None:\n            seg, _ = rw.read_seg(seg_file)\n        else:\n            seg = None\n\n        if self.verbose:\n            print(seg_file)\n        data, seg, data_properties = self.run_case_npy(data, seg, data_properties, plans_manager, configuration_manager,\n                                      dataset_json)\n        return data, seg, data_properties\n\n    def run_case_save(self, output_filename_truncated: str, image_files: List[str], seg_file: str,\n                      plans_manager: PlansManager, configuration_manager: ConfigurationManager,\n                      dataset_json: Union[dict, str]):\n        data, seg, properties = self.run_case(image_files, seg_file, plans_manager, configuration_manager, dataset_json)\n        data = data.astype(np.float32, copy=False)\n        seg = seg.astype(np.int16, copy=False)\n        # print('dtypes', data.dtype, seg.dtype)\n        block_size_data, chunk_size_data = nnUNetDatasetBlosc2.comp_blosc2_params(\n            data.shape,\n            tuple(configuration_manager.patch_size),\n            data.itemsize)\n        block_size_seg, chunk_size_seg = nnUNetDatasetBlosc2.comp_blosc2_params(\n            seg.shape,\n            tuple(configuration_manager.patch_size),\n            seg.itemsize)\n\n        nnUNetDatasetBlosc2.save_case(data, seg, properties, output_filename_truncated,\n                                      chunks=chunk_size_data, blocks=block_size_data,\n                                      chunks_seg=chunk_size_seg, blocks_seg=block_size_seg)\n\n    @staticmethod\n    def _sample_foreground_locations(\n            seg: np.ndarray,\n            classes_or_regions: Union[List[int], List[Tuple[int, 
...]]],\n            seed: int = 1234,\n            verbose: bool = False,\n            min_num_samples=10000,\n            min_percent_coverage = 0.01\n    ):\n\n        rndst = np.random.RandomState(seed)\n\n        class_locs = {}\n\n        # Normalize requested labels and compute the set of all labels we might need\n        normalized = []\n        requested_labels = set()\n        for c in classes_or_regions:\n            if isinstance(c, (tuple, list)):\n                labs = tuple(int(x) for x in c)\n                normalized.append(labs)\n                requested_labels.update(labs)\n            else:\n                lab = int(c)\n                normalized.append(lab)\n                requested_labels.add(lab)\n\n        # Create mask for all requested labels (this includes 0 if requested)\n        requested_labels_arr = np.fromiter(requested_labels, dtype=np.int32)\n        valid_mask = np.isin(seg, requested_labels_arr)\n\n        coords = np.argwhere(valid_mask)\n        seg_sel = seg[valid_mask]\n        del valid_mask\n\n        n = seg_sel.size\n        if n == 0:\n            for c in classes_or_regions:\n                k = tuple(c) if isinstance(c, (tuple, list)) else int(c)\n                class_locs[k] = []\n            return class_locs\n\n        # sort once, then compute label blocks\n        order = np.argsort(seg_sel, kind=\"stable\")\n        lab_sorted = seg_sel[order]\n        coords_sorted = coords[order]\n\n        change = np.flatnonzero(lab_sorted[1:] != lab_sorted[:-1]) + 1\n        starts = np.r_[0, change]\n        ends = np.r_[change, n]\n        labels_present = lab_sorted[starts]\n\n        label_to_range = {int(l): (int(s), int(e)) for l, s, e in zip(labels_present, starts, ends)}\n        present_labels = set(label_to_range.keys())\n\n        for c in classes_or_regions:\n            is_region = isinstance(c, (tuple, list))\n            labs = tuple(int(x) for x in c) if is_region else (int(c),)\n            k = labs if 
is_region else labs[0]\n\n            # Skip if none of the labels are present\n            if not any(lab in present_labels for lab in labs):\n                class_locs[k] = []\n                continue\n\n            # Collect ranges for present labels in this class/region\n            ranges = []\n            counts = []\n            for lab in labs:\n                r = label_to_range.get(lab)\n                if r is None:\n                    continue\n                s, e = r\n                cnt = e - s\n                if cnt > 0:\n                    ranges.append((s, e))\n                    counts.append(cnt)\n\n            if len(counts) == 0:\n                class_locs[k] = []\n                continue\n\n            total = int(np.sum(counts))\n            target_num_samples = min(min_num_samples, total)\n            target_num_samples = max(target_num_samples, int(np.ceil(total * min_percent_coverage)))\n\n            # Sample uniformly without replacement from the union of ranges, without building an n-sized mask\n            # Draw target_num_samples unique offsets in [0, total)\n            offsets = rndst.choice(total, target_num_samples, replace=False)\n\n            # Map offsets -> (range index, in-range offset) using cumulative counts\n            cum = np.cumsum(counts)\n            which = np.searchsorted(cum, offsets, side=\"right\")\n            prev = np.concatenate(([0], cum[:-1]))\n            in_range = offsets - prev[which]\n\n            # Convert to indices in coords_sorted\n            starts_for_pick = np.fromiter((ranges[i][0] for i in which), dtype=np.int64, count=which.size)\n            picked_idx = starts_for_pick + in_range.astype(np.int64)\n\n            selected = coords_sorted[picked_idx]\n            class_locs[k] = selected\n\n            if verbose:\n                print(c, target_num_samples)\n\n        return class_locs\n\n    # @staticmethod\n    # def _sample_foreground_locations(seg: np.ndarray, 
classes_or_regions: Union[List[int], List[Tuple[int, ...]]],\n    #                                  seed: int = 1234, verbose: bool = False):\n    #     num_samples = 10000\n    #     min_percent_coverage = 0.01  # at least 1% of the class voxels need to be selected, otherwise it may be too\n    #     # sparse\n    #     rndst = np.random.RandomState(seed)\n    #     class_locs = {}\n    #     foreground_mask = seg != 0\n    #     foreground_coords = np.argwhere(foreground_mask)\n    #     seg = seg[foreground_mask]\n    #     del foreground_mask\n    #     unique_labels = pd.unique(seg.ravel())\n    #\n    #     # We don't need more than 1e7 foreground samples. That's insanity. Cap here\n    #     if len(foreground_coords) > 1e7:\n    #         take_every = math.floor(len(foreground_coords) / 1e7)\n    #         # keep computation time reasonable\n    #         if verbose:\n    #             print(f'Subsampling foreground pixels 1:{take_every} for computational reasons')\n    #         foreground_coords = foreground_coords[::take_every]\n    #         seg = seg[::take_every]\n    #\n    #     for c in classes_or_regions:\n    #         k = c if not isinstance(c, list) else tuple(c)\n    #\n    #         # check if any of the labels are in seg, if not skip c\n    #         if isinstance(c, (tuple, list)):\n    #             if not any([ci in unique_labels for ci in c]):\n    #                 class_locs[k] = []\n    #                 continue\n    #         else:\n    #             if c not in unique_labels:\n    #                 class_locs[k] = []\n    #                 continue\n    #\n    #         if isinstance(c, (tuple, list)):\n    #             mask = seg == c[0]\n    #             for cc in c[1:]:\n    #                 mask = mask | (seg == cc)\n    #             all_locs = foreground_coords[mask]\n    #         else:\n    #             mask = seg == c\n    #             all_locs = foreground_coords[mask]\n    #         if len(all_locs) == 0:\n    #     
        class_locs[k] = []\n    #             continue\n    #         target_num_samples = min(num_samples, len(all_locs))\n    #         target_num_samples = max(target_num_samples, int(np.ceil(len(all_locs) * min_percent_coverage)))\n    #\n    #         selected = all_locs[rndst.choice(len(all_locs), target_num_samples, replace=False)]\n    #         class_locs[k] = selected\n    #         if verbose:\n    #             print(c, target_num_samples)\n    #         seg = seg[~mask]\n    #         foreground_coords = foreground_coords[~mask]\n    #     return class_locs\n\n    def _normalize(self, data: np.ndarray, seg: np.ndarray, configuration_manager: ConfigurationManager,\n                   foreground_intensity_properties_per_channel: dict) -> np.ndarray:\n        for c in range(data.shape[0]):\n            scheme = configuration_manager.normalization_schemes[c]\n            normalizer_class = recursive_find_python_class(join(nnunetv2.__path__[0], \"preprocessing\", \"normalization\"),\n                                                           scheme,\n                                                           'nnunetv2.preprocessing.normalization')\n            if normalizer_class is None:\n                raise RuntimeError(f'Unable to locate class \\'{scheme}\\' for normalization')\n            normalizer = normalizer_class(use_mask_for_norm=configuration_manager.use_mask_for_norm[c],\n                                          intensityproperties=foreground_intensity_properties_per_channel[str(c)])\n            data[c] = normalizer.run(data[c], seg[0])\n        return data\n\n    def run(self, dataset_name_or_id: Union[int, str], configuration_name: str, plans_identifier: str,\n            num_processes: int):\n        \"\"\"\n        data identifier = configuration name in plans. 
EZ.\n        \"\"\"\n        dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)\n\n        assert isdir(join(nnUNet_raw, dataset_name)), \"The requested dataset could not be found in nnUNet_raw\"\n\n        plans_file = join(nnUNet_preprocessed, dataset_name, plans_identifier + '.json')\n        assert isfile(plans_file), \"Expected plans file (%s) not found. Run corresponding nnUNet_plan_experiment \" \\\n                                   \"first.\" % plans_file\n        plans = load_json(plans_file)\n        plans_manager = PlansManager(plans)\n        configuration_manager = plans_manager.get_configuration(configuration_name)\n\n        dataset_json_file = join(nnUNet_preprocessed, dataset_name, 'dataset.json')\n        dataset_json = load_json(dataset_json_file)\n\n        output_directory = join(nnUNet_preprocessed, dataset_name, configuration_manager.data_identifier)\n\n        if isdir(output_directory):\n            shutil.rmtree(output_directory)\n\n        maybe_mkdir_p(output_directory)\n\n        dataset = get_filenames_of_train_images_and_targets(join(nnUNet_raw, dataset_name), dataset_json)\n\n        # identifiers = [os.path.basename(i[:-len(dataset_json['file_ending'])]) for i in seg_fnames]\n        # output_filenames_truncated = [join(output_directory, i) for i in identifiers]\n\n        # multiprocessing magic.\n        r = []\n        with multiprocessing.get_context(\"spawn\").Pool(num_processes) as p:\n            remaining = list(range(len(dataset)))\n            # p is pretty nifti. 
If we kill workers they just respawn but don't do any work.\n            # So we need to store the original pool of workers.\n            workers = [j for j in p._pool]\n            for k in dataset.keys():\n                r.append(p.starmap_async(self.run_case_save,\n                                         ((join(output_directory, k), dataset[k]['images'], dataset[k]['label'],\n                                           plans_manager, configuration_manager,\n                                           dataset_json),)))\n\n            with tqdm(desc=\"Preprocessing cases\", total=len(dataset),\n                      disable=not getattr(self, 'show_progress_bar', True)) as pbar:\n                while len(remaining) > 0:\n                    all_alive = all([j.is_alive() for j in workers])\n                    if not all_alive:\n                        raise RuntimeError('Some background worker is 6 feet under. Yuck. \\n'\n                                           'OK jokes aside.\\n'\n                                           'One of your background processes is missing. This could be because of '\n                                           'an error (look for an error message) or because it was killed '\n                                           'by your OS due to running out of RAM. If you don\\'t see '\n                                           'an error message, out of RAM is likely the problem. 
In that case '\n                                           'reducing the number of workers might help')\n                    done = [i for i in remaining if r[i].ready()]\n                    for i in done:\n                        r[i].get()  # trigger any errors from worker (single call, no duplicate)\n                        pbar.update()\n                    remaining = [i for i in remaining if i not in done]\n                    sleep(0.1)\n\n    def modify_seg_fn(self, seg: np.ndarray, plans_manager: PlansManager, dataset_json: dict,\n                      configuration_manager: ConfigurationManager) -> np.ndarray:\n        # this function will be called at the end of self.run_case. Can be used to change the segmentation\n        # after resampling. Useful for experimenting with sparse annotations: I can introduce sparsity after resampling\n        # and don't have to create a new dataset each time I modify my experiments\n        return seg\n\n\ndef example_test_case_preprocessing():\n    # (paths to files may need adaptations)\n    plans_file = '/home/isensee/drives/gpu_data/nnUNet_preprocessed/Dataset219_AMOS2022_postChallenge_task2/nnUNetPlans.json'\n    dataset_json_file = '/home/isensee/drives/gpu_data/nnUNet_preprocessed/Dataset219_AMOS2022_postChallenge_task2/dataset.json'\n    input_images = ['/home/isensee/drives/e132-rohdaten/nnUNetv2/Dataset219_AMOS2022_postChallenge_task2/imagesTr/amos_0600_0000.nii.gz', ]  # if you only have one channel, you still need a list: ['case000_0000.nii.gz']\n\n    configuration = '3d_fullres'\n    pp = DefaultPreprocessor()\n\n    # _ because this position would be the segmentation if seg_file was not None (training case)\n    # even if you have the segmentation, don't put the file there! You should always evaluate in the original\n    # resolution. 
What comes out of the preprocessor might have been resampled to some other image resolution (as\n    # specified by plans)\n    plans_manager = PlansManager(plans_file)\n    data, _, properties = pp.run_case(input_images, seg_file=None, plans_manager=plans_manager,\n                                      configuration_manager=plans_manager.get_configuration(configuration),\n                                      dataset_json=dataset_json_file)\n\n    # voila. Now plug data into your prediction function of choice. We of course recommend nnU-Net's default (TODO)\n    return data\n\n\ndef _verify_class_locations(shape, outfile, class_locs):\n    import numpy as np\n    import SimpleITK as sitk\n\n    out = np.zeros(shape, dtype=np.uint16)  # allow many labels safely\n\n    for i, k in enumerate(class_locs.keys()):\n        class_coords = class_locs[k]\n        # absent classes are stored as empty lists, which cannot be sliced with [:, 1:]\n        if class_coords is None or len(class_coords) == 0:\n            continue\n        # drop the leading channel coordinate, keeping (z, y, x)\n        class_coords = np.asarray(class_coords)[:, 1:]\n        if class_coords.size == 0:\n            continue\n\n        # Expect coords in (N, 3) as (z, y, x)\n        if class_coords.ndim != 2 or class_coords.shape[1] != 3:\n            raise ValueError(f\"class_locs[{k}] must have shape (N, 3), got {class_coords.shape}\")\n\n        z = class_coords[:, 0].astype(np.int64)\n        y = class_coords[:, 1].astype(np.int64)\n        x = class_coords[:, 2].astype(np.int64)\n\n        # Optional bounds check (cheap and prevents hard-to-debug indexing errors)\n        if (z.min() < 0 or y.min() < 0 or x.min() < 0 or\n                z.max() >= shape[0] or y.max() >= shape[1] or x.max() >= shape[2]):\n            raise ValueError(f\"Coordinates for {k} are out of bounds for shape={shape}\")\n\n        out[z, y, x] = i + 1  # label 1..K\n\n    img = sitk.GetImageFromArray(out)  # SimpleITK assumes array is z,y,x\n    sitk.WriteImage(img, outfile)\n\n\nif __name__ == '__main__':\n    # example_test_case_preprocessing()\n    # pp = DefaultPreprocessor()\n    # pp.run(2, 
'2d', 'nnUNetPlans', 8)\n\n    ###########################################################################################################\n    # how to process a test case? This is an example:\n    # example_test_case_preprocessing()\n    seg = SimpleITK.GetArrayFromImage(SimpleITK.ReadImage('/home/isensee/temp/H-mito-val-v2.nii.gz'))[None]\n    a = DefaultPreprocessor._sample_foreground_locations(seg, np.arange(1, np.max(seg) + 1), min_percent_coverage=0.50)\n\n    _verify_class_locations(seg.shape[1:], '/home/isensee/temp/deleteme.nii.gz', a)\n"
  },
  {
    "path": "nnunetv2/preprocessing/resampling/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/preprocessing/resampling/default_resampling.py",
    "content": "from collections import OrderedDict\nfrom copy import deepcopy\nfrom typing import Union, Tuple, List\n\nimport numpy as np\nimport pandas as pd\nimport torch\nfrom batchgenerators.augmentations.utils import resize_segmentation\nfrom scipy.ndimage import map_coordinates\nfrom skimage.transform import resize\nfrom nnunetv2.configuration import ANISO_THRESHOLD\n\n\ndef get_do_separate_z(spacing: Union[Tuple[float, ...], List[float], np.ndarray], anisotropy_threshold=ANISO_THRESHOLD):\n    do_separate_z = (np.max(spacing) / np.min(spacing)) > anisotropy_threshold\n    return do_separate_z\n\n\ndef get_lowres_axis(new_spacing: Union[Tuple[float, ...], List[float], np.ndarray]):\n    axis = np.where(max(new_spacing) / np.array(new_spacing) == 1)[0]  # find which axis is anisotropic\n    return axis\n\n\ndef compute_new_shape(old_shape: Union[Tuple[int, ...], List[int], np.ndarray],\n                      old_spacing: Union[Tuple[float, ...], List[float], np.ndarray],\n                      new_spacing: Union[Tuple[float, ...], List[float], np.ndarray]) -> np.ndarray:\n    assert len(old_spacing) == len(old_shape)\n    assert len(old_shape) == len(new_spacing)\n    new_shape = np.array([int(round(i / j * k)) for i, j, k in zip(old_spacing, new_spacing, old_shape)])\n    return new_shape\n\n\ndef determine_do_sep_z_and_axis(\n        force_separate_z: bool,\n        current_spacing,\n        new_spacing,\n        separate_z_anisotropy_threshold: float = ANISO_THRESHOLD) -> Tuple[bool, Union[int, None]]:\n    if force_separate_z is not None:\n        do_separate_z = force_separate_z\n        if force_separate_z:\n            axis = get_lowres_axis(current_spacing)\n        else:\n            axis = None\n    else:\n        if get_do_separate_z(current_spacing, separate_z_anisotropy_threshold):\n            do_separate_z = True\n            axis = get_lowres_axis(current_spacing)\n        elif get_do_separate_z(new_spacing, 
separate_z_anisotropy_threshold):\n            do_separate_z = True\n            axis = get_lowres_axis(new_spacing)\n        else:\n            do_separate_z = False\n            axis = None\n\n    if axis is not None:\n        if len(axis) == 3:\n            do_separate_z = False\n            axis = None\n        elif len(axis) == 2:\n            # this happens for spacings like (0.24, 1.25, 1.25) for example. In that case we do not want to resample\n            # separately in the out of plane axis\n            do_separate_z = False\n            axis = None\n        else:\n            axis = axis[0]\n    return do_separate_z, axis\n\n\ndef resample_data_or_seg_to_spacing(data: np.ndarray,\n                                    current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],\n                                    new_spacing: Union[Tuple[float, ...], List[float], np.ndarray],\n                                    is_seg: bool = False,\n                                    order: int = 3, order_z: int = 0,\n                                    force_separate_z: Union[bool, None] = False,\n                                    separate_z_anisotropy_threshold: float = ANISO_THRESHOLD):\n    do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing,\n                                                      separate_z_anisotropy_threshold)\n\n    if data is not None:\n        assert data.ndim == 4, \"data must be c x y z\"\n\n    shape = np.array(data.shape)\n    new_shape = compute_new_shape(shape[1:], current_spacing, new_spacing)\n\n    data_reshaped = resample_data_or_seg(data, new_shape, is_seg, axis, order, do_separate_z, order_z=order_z)\n    return data_reshaped\n\n\ndef resample_data_or_seg_to_shape(data: Union[torch.Tensor, np.ndarray],\n                                  new_shape: Union[Tuple[int, ...], List[int], np.ndarray],\n                                  current_spacing: Union[Tuple[float, ...], 
List[float], np.ndarray],\n                                  new_spacing: Union[Tuple[float, ...], List[float], np.ndarray],\n                                  is_seg: bool = False,\n                                  order: int = 3, order_z: int = 0,\n                                  force_separate_z: Union[bool, None] = False,\n                                  separate_z_anisotropy_threshold: float = ANISO_THRESHOLD):\n    \"\"\"\n    needed for segmentation export. Stupid, I know\n    \"\"\"\n    if isinstance(data, torch.Tensor):\n        data = data.numpy()\n\n    do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing,\n                                                      separate_z_anisotropy_threshold)\n\n    if data is not None:\n        assert data.ndim == 4, \"data must be c x y z\"\n\n    data_reshaped = resample_data_or_seg(data, new_shape, is_seg, axis, order, do_separate_z, order_z=order_z)\n    return data_reshaped\n\n\ndef resample_data_or_seg(data: np.ndarray, new_shape: Union[Tuple[float, ...], List[float], np.ndarray],\n                         is_seg: bool = False, axis: Union[None, int] = None, order: int = 3,\n                         do_separate_z: bool = False, order_z: int = 0, dtype_out = None):\n    \"\"\"\n    separate_z=True will resample with order 0 along z\n    :param data:\n    :param new_shape:\n    :param is_seg:\n    :param axis:\n    :param order:\n    :param do_separate_z:\n    :param order_z: only applies if do_separate_z is True\n    :return:\n    \"\"\"\n    assert data.ndim == 4, \"data must be (c, x, y, z)\"\n    assert len(new_shape) == data.ndim - 1\n\n    if is_seg:\n        resize_fn = resize_segmentation\n        kwargs = OrderedDict()\n    else:\n        resize_fn = resize\n        kwargs = {'mode': 'edge', 'anti_aliasing': False}\n    shape = np.array(data[0].shape)\n    new_shape = np.array(new_shape)\n    if dtype_out is None:\n        dtype_out = data.dtype\n    
reshaped_final = np.zeros((data.shape[0], *new_shape), dtype=dtype_out)\n    if np.any(shape != new_shape):\n        data = data.astype(float, copy=False)\n        if do_separate_z:\n            assert axis is not None, 'If do_separate_z, we need to know what axis is anisotropic'\n            if axis == 0:\n                new_shape_2d = new_shape[1:]\n            elif axis == 1:\n                new_shape_2d = new_shape[[0, 2]]\n            else:\n                new_shape_2d = new_shape[:-1]\n\n            for c in range(data.shape[0]):\n                tmp = deepcopy(new_shape)\n                tmp[axis] = shape[axis]\n                reshaped_here = np.zeros(tmp)\n                for slice_id in range(shape[axis]):\n                    if axis == 0:\n                        reshaped_here[slice_id] = resize_fn(data[c, slice_id], new_shape_2d, order, **kwargs)\n                    elif axis == 1:\n                        reshaped_here[:, slice_id] = resize_fn(data[c, :, slice_id], new_shape_2d, order, **kwargs)\n                    else:\n                        reshaped_here[:, :, slice_id] = resize_fn(data[c, :, :, slice_id], new_shape_2d, order, **kwargs)\n                if shape[axis] != new_shape[axis]:\n\n                    # The following few lines are blatantly copied and modified from sklearn's resize()\n                    rows, cols, dim = new_shape[0], new_shape[1], new_shape[2]\n                    orig_rows, orig_cols, orig_dim = reshaped_here.shape\n\n                    # align_corners=False\n                    row_scale = float(orig_rows) / rows\n                    col_scale = float(orig_cols) / cols\n                    dim_scale = float(orig_dim) / dim\n\n                    map_rows, map_cols, map_dims = np.mgrid[:rows, :cols, :dim]\n                    map_rows = row_scale * (map_rows + 0.5) - 0.5\n                    map_cols = col_scale * (map_cols + 0.5) - 0.5\n                    map_dims = dim_scale * (map_dims + 0.5) - 0.5\n\n       
             coord_map = np.array([map_rows, map_cols, map_dims])\n                    if not is_seg or order_z == 0:\n                        reshaped_final[c] = map_coordinates(reshaped_here, coord_map, order=order_z, mode='nearest')[None]\n                    else:\n                        unique_labels = np.sort(pd.unique(reshaped_here.ravel()))  # np.unique(reshaped_data)\n                        for i, cl in enumerate(unique_labels):\n                            reshaped_final[c][np.round(\n                                map_coordinates((reshaped_here == cl).astype(float), coord_map, order=order_z,\n                                                mode='nearest')) > 0.5] = cl\n                else:\n                    reshaped_final[c] = reshaped_here\n        else:\n            for c in range(data.shape[0]):\n                reshaped_final[c] = resize_fn(data[c], new_shape, order, **kwargs)\n        return reshaped_final\n    else:\n        # print(\"no resampling necessary\")\n        return data\n    \n\nif __name__ == '__main__':\n    input_array = np.random.random((1, 42, 231, 142))\n    output_shape = (52, 256, 256)\n    out = resample_data_or_seg(input_array, output_shape, is_seg=False, axis=3, order=1, order_z=0, do_separate_z=True)\n    print(out.shape, input_array.shape)\n"
  },
  {
    "path": "nnunetv2/preprocessing/resampling/no_resampling.py",
    "content": "from typing import Union, Tuple, List\n\nimport numpy as np\nimport torch\n\n\ndef no_resampling_hack(\n        data: Union[torch.Tensor, np.ndarray],\n        new_shape: Union[Tuple[int, ...], List[int], np.ndarray],\n        current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],\n        new_spacing: Union[Tuple[float, ...], List[float], np.ndarray]\n):\n    return data"
  },
  {
    "path": "nnunetv2/preprocessing/resampling/resample_torch.py",
    "content": "from copy import deepcopy\nfrom typing import Union, Tuple, List\n\nimport numpy as np\nimport torch\nfrom einops import rearrange\nfrom torch.nn import functional as F\n\nfrom nnunetv2.configuration import ANISO_THRESHOLD\nfrom nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO\nfrom nnunetv2.preprocessing.resampling.default_resampling import determine_do_sep_z_and_axis\n\n\ndef resample_torch_simple(\n        data: Union[torch.Tensor, np.ndarray],\n        new_shape: Union[Tuple[int, ...], List[int], np.ndarray],\n        is_seg: bool = False,\n        num_threads: int = 4,\n        device: torch.device = torch.device('cpu'),\n        memefficient_seg_resampling: bool = False,\n        mode='linear'\n):\n    if mode == 'linear':\n        if data.ndim == 4:\n            torch_mode = 'trilinear'\n        elif data.ndim == 3:\n            torch_mode = 'bilinear'\n        else:\n            raise RuntimeError\n    else:\n        torch_mode = mode\n\n    if isinstance(new_shape, np.ndarray):\n        new_shape = [int(i) for i in new_shape]\n\n    if all([i == j for i, j in zip(new_shape, data.shape[1:])]):\n        return data\n    else:\n        n_threads = torch.get_num_threads()\n        torch.set_num_threads(num_threads)\n        new_shape = tuple(new_shape)\n        with torch.no_grad():\n\n            input_was_numpy = isinstance(data, np.ndarray)\n            if input_was_numpy:\n                data = torch.from_numpy(data).to(device)\n            else:\n                orig_device = deepcopy(data.device)\n                data = data.to(device)\n\n            if is_seg:\n                unique_values = torch.unique(data)\n                result_dtype = torch.int8 if max(unique_values) < 127 else torch.int16\n                result = torch.zeros((data.shape[0], *new_shape), dtype=result_dtype, device=device)\n                if not memefficient_seg_resampling:\n                    # believe it or not, the implementation below is 3x as 
fast (at least on Liver CT and on CPU)\n                    # Why? Because argmax is slow. The implementation below immediately sets most locations and only lets the\n                    # uncertain ones be determined by argmax\n\n                    # unique_values = torch.unique(data)\n                    # result = torch.zeros((len(unique_values), data.shape[0], *new_shape), dtype=torch.float16)\n                    # for i, u in enumerate(unique_values):\n                    #     result[i] = F.interpolate((data[None] == u).float() * 1000, new_shape, mode='trilinear', antialias=False)[0]\n                    # result = unique_values[result.argmax(0)]\n\n                    result_tmp = torch.empty((len(unique_values), data.shape[0], *new_shape), dtype=torch.float16,\n                                             device=device)\n                    scale_factor = 1000\n                    done_mask = torch.zeros_like(result, dtype=torch.bool, device=device)\n                    for i, u in enumerate(unique_values):\n                        result_tmp[i] = \\\n                            F.interpolate((data[None] == u).float() * scale_factor, new_shape, mode=torch_mode,\n                                          antialias=False)[0]\n                        mask = result_tmp[i] > (0.7 * scale_factor)\n                        result[mask] = u.item()\n                        done_mask |= mask\n                    if not torch.all(done_mask):\n                        # print('resolving argmax', torch.sum(~done_mask), \"voxels to go\")\n                        result[~done_mask] = unique_values[result_tmp[:, ~done_mask].argmax(0)].to(result_dtype)\n                else:\n                    for i, u in enumerate(unique_values):\n                        if u == 0:\n                            pass\n                        result[F.interpolate((data[None] == u).float(), new_shape, mode=torch_mode, antialias=False)[\n                                   0] > 0.5] = u\n      
      else:\n                result = F.interpolate(data[None].float(), new_shape, mode=torch_mode, antialias=False)[0]\n            if input_was_numpy:\n                result = result.cpu().numpy()\n            else:\n                result = result.to(orig_device)\n        torch.set_num_threads(n_threads)\n        return result\n\n\ndef resample_torch_fornnunet(\n        data: Union[torch.Tensor, np.ndarray],\n        new_shape: Union[Tuple[int, ...], List[int], np.ndarray],\n        current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],\n        new_spacing: Union[Tuple[float, ...], List[float], np.ndarray],\n        is_seg: bool = False,\n        num_threads: int = 4,\n        device: torch.device = torch.device('cpu'),\n        memefficient_seg_resampling: bool = False,\n        force_separate_z: Union[bool, None] = None,\n        separate_z_anisotropy_threshold: float = ANISO_THRESHOLD,\n        mode='linear',\n        aniso_axis_mode='nearest-exact'\n):\n    \"\"\"\n    data must be c, x, y, z\n    \"\"\"\n    assert data.ndim == 4, \"data must be c, x, y, z\"\n    new_shape = [int(i) for i in new_shape]\n    orig_shape = data.shape\n\n    do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing,\n                                                      separate_z_anisotropy_threshold)\n    if not isinstance(axis, (tuple, list)):\n        axis = (axis,)\n    # print('shape', data.shape, 'current_spacing', current_spacing, 'new_spacing', new_spacing, 'do_separate_z', do_separate_z, 'axis', axis)\n\n    if do_separate_z:\n        was_numpy = isinstance(data, np.ndarray)\n        if was_numpy:\n            data = torch.from_numpy(data)\n\n        assert len(axis) == 1\n        axis = axis[0]\n        tmp = \"xyz\"\n        axis_letter = tmp[axis]\n        others_int = [i for i in range(3) if i != axis]\n        others = [tmp[i] for i in others_int]\n\n        # reshape by overloading c channel\n        data = 
rearrange(data, f\"c x y z -> (c {axis_letter}) {others[0]} {others[1]}\")\n\n        # reshape in-plane\n        tmp_new_shape = [new_shape[i] for i in others_int]\n        data = resample_torch_simple(data, tmp_new_shape, is_seg=is_seg, num_threads=num_threads, device=device,\n                                     memefficient_seg_resampling=memefficient_seg_resampling, mode=mode)\n        data = rearrange(data, f\"(c {axis_letter}) {others[0]} {others[1]} -> c x y z\",\n                         **{\n                             axis_letter: orig_shape[axis + 1],\n                             others[0]: tmp_new_shape[0],\n                             others[1]: tmp_new_shape[1]\n                         }\n                         )\n        # reshape out of plane w/ nearest\n        data = resample_torch_simple(data, new_shape, is_seg=is_seg, num_threads=num_threads, device=device,\n                                     memefficient_seg_resampling=memefficient_seg_resampling, mode=aniso_axis_mode)\n        if was_numpy:\n            data = data.numpy()\n        return data\n    else:\n        # forward mode so that a non-default interpolation mode is also honored without separate z resampling\n        return resample_torch_simple(data, new_shape, is_seg, num_threads, device, memefficient_seg_resampling,\n                                     mode=mode)\n\n\nif __name__ == '__main__':\n    torch.set_num_threads(16)\n    img_file = '/media/isensee/raw_data/nnUNet_raw/Dataset027_ACDC/imagesTr/patient041_frame01_0000.nii.gz'\n    seg_file = '/media/isensee/raw_data/nnUNet_raw/Dataset027_ACDC/labelsTr/patient041_frame01.nii.gz'\n    io = SimpleITKIO()\n    data, pkl = io.read_images((img_file, ))\n    seg, pkl = io.read_seg(seg_file)\n\n    target_shape = (15, 256, 312)\n    spacing = pkl['spacing']\n\n    use = data\n    is_seg = False\n\n    ret_nosep = resample_torch_fornnunet(use, target_shape, spacing, spacing, is_seg)\n    ret_sep = resample_torch_fornnunet(use, target_shape, spacing, spacing, is_seg, force_separate_z=False)\n\n"
  },
  {
    "path": "nnunetv2/preprocessing/resampling/utils.py",
    "content": "from typing import Callable\n\nimport nnunetv2\nfrom batchgenerators.utilities.file_and_folder_operations import join\nfrom nnunetv2.utilities.find_class_by_name import recursive_find_python_class\n\n\ndef recursive_find_resampling_fn_by_name(resampling_fn: str) -> Callable:\n    ret = recursive_find_python_class(join(nnunetv2.__path__[0], \"preprocessing\", \"resampling\"), resampling_fn,\n                                      'nnunetv2.preprocessing.resampling')\n    if ret is None:\n        raise RuntimeError(\"Unable to find resampling function named '%s'. Please make sure this fn is located in the \"\n                           \"nnunetv2.preprocessing.resampling module.\" % resampling_fn)\n    else:\n        return ret\n"
  },
  {
    "path": "nnunetv2/run/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/run/load_pretrained_weights.py",
    "content": "import torch\nfrom torch._dynamo import OptimizedModule\nfrom torch.nn.parallel import DistributedDataParallel as DDP\nimport torch.distributed as dist\n\n\ndef load_pretrained_weights(network, fname, verbose=False):\n    \"\"\"\n    Transfers all weights between matching keys in state_dicts. matching is done by name and we only transfer if the\n    shape is also the same. Segmentation layers (the 1x1(x1) layers that produce the segmentation maps)\n    identified by keys ending with '.seg_layers') are not transferred!\n\n    If the pretrained weights were obtained with a training outside nnU-Net and DDP or torch.optimize was used,\n    you need to change the keys of the pretrained state_dict. DDP adds a 'module.' prefix and torch.optim adds\n    '_orig_mod'. You DO NOT need to worry about this if pretraining was done with nnU-Net as\n    nnUNetTrainer.save_checkpoint takes care of that!\n\n    \"\"\"\n    if dist.is_initialized():\n        saved_model = torch.load(fname, map_location=torch.device('cuda', dist.get_rank()), weights_only=False)\n    else:\n        saved_model = torch.load(fname, weights_only=False)\n    pretrained_dict = saved_model['network_weights']\n\n    skip_strings_in_pretrained = [\n        '.seg_layers.',\n    ]\n\n    if isinstance(network, DDP):\n        mod = network.module\n    else:\n        mod = network\n    if isinstance(mod, OptimizedModule):\n        mod = mod._orig_mod\n\n    model_dict = mod.state_dict()\n    # verify that all but the segmentation layers have the same shape\n    for key, _ in model_dict.items():\n        if all([i not in key for i in skip_strings_in_pretrained]):\n            assert key in pretrained_dict, \\\n                f\"Key {key} is missing in the pretrained model weights. 
The pretrained weights do not seem to be \" \\\n                f\"compatible with your network.\"\n            assert model_dict[key].shape == pretrained_dict[key].shape, \\\n                f\"The shape of the parameters of key {key} is not the same. Pretrained model: \" \\\n                f\"{pretrained_dict[key].shape}; your network: {model_dict[key].shape}. The pretrained model \" \\\n                f\"does not seem to be compatible with your network.\"\n\n    # fun fact: in principle this allows loading from parameters that do not cover the entire network. For example pretrained\n    # encoders. Not supported by this function though (see assertions above)\n\n    # commenting out this abomination of a dict comprehension for preservation in the archives of 'what not to do'\n    # pretrained_dict = {'module.' + k if is_ddp else k: v\n    #                    for k, v in pretrained_dict.items()\n    #                    if (('module.' + k if is_ddp else k) in model_dict) and\n    #                    all([i not in k for i in skip_strings_in_pretrained])}\n\n    pretrained_dict = {k: v for k, v in pretrained_dict.items()\n                       if k in model_dict.keys() and all([i not in k for i in skip_strings_in_pretrained])}\n\n    model_dict.update(pretrained_dict)\n\n    print(\"################### Loading pretrained weights from file \", fname, '###################')\n    if verbose:\n        print(\"Below is the list of overlapping blocks in pretrained model and nnUNet architecture:\")\n        for key, value in pretrained_dict.items():\n            print(key, 'shape', value.shape)\n        print(\"################### Done ###################\")\n    mod.load_state_dict(model_dict)\n\n\n"
  },
  {
    "path": "nnunetv2/run/run_training.py",
    "content": "import multiprocessing\nimport os\nimport socket\nfrom typing import Union, Optional\n\nimport nnunetv2\nimport torch.cuda\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\nfrom batchgenerators.utilities.file_and_folder_operations import join, isfile, load_json\nfrom nnunetv2.paths import nnUNet_preprocessed\nfrom nnunetv2.run.load_pretrained_weights import load_pretrained_weights\nfrom nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer\nfrom nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name\nfrom nnunetv2.utilities.find_class_by_name import recursive_find_python_class\nfrom torch.backends import cudnn\n\n\ndef find_free_network_port() -> int:\n    \"\"\"Finds a free port on localhost.\n\n    It is useful in single-node training when we don't want to connect to a real main node but have to set the\n    `MASTER_PORT` environment variable.\n    \"\"\"\n    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n    s.bind((\"\", 0))\n    port = s.getsockname()[1]\n    s.close()\n    return port\n\n\ndef get_trainer_from_args(dataset_name_or_id: Union[int, str],\n                          configuration: str,\n                          fold: int,\n                          trainer_name: str = 'nnUNetTrainer',\n                          plans_identifier: str = 'nnUNetPlans',\n                          continue_training: bool = False,\n                          device: torch.device = torch.device('cuda')):\n    # load nnunet class and do sanity checks\n    nnunet_trainer = recursive_find_python_class(join(nnunetv2.__path__[0], \"training\", \"nnUNetTrainer\"),\n                                                trainer_name, 'nnunetv2.training.nnUNetTrainer')\n    if nnunet_trainer is None:\n        raise RuntimeError(f'Could not find requested nnunet trainer {trainer_name} in '\n                           f'nnunetv2.training.nnUNetTrainer ('\n                           
f'{join(nnunetv2.__path__[0], \"training\", \"nnUNetTrainer\")}). If it is located somewhere '\n                           f'else, please move it there.')\n    assert issubclass(nnunet_trainer, nnUNetTrainer), 'The requested nnunet trainer class must inherit from ' \\\n                                                    'nnUNetTrainer'\n\n    # handle dataset input. If it's an ID we need to convert to int from string. The isinstance check keeps\n    # plain int IDs (allowed by the type hint) from failing on str.startswith\n    if isinstance(dataset_name_or_id, str) and dataset_name_or_id.startswith('Dataset'):\n        pass\n    else:\n        try:\n            dataset_name_or_id = int(dataset_name_or_id)\n        except ValueError:\n            raise ValueError(f'dataset_name_or_id must either be an integer or a valid dataset name with the pattern '\n                             f'DatasetXXX_YYY where XXX are the three(!) task ID digits. Your '\n                             f'input: {dataset_name_or_id}')\n\n    # initialize nnunet trainer\n    preprocessed_dataset_folder_base = join(nnUNet_preprocessed, maybe_convert_to_dataset_name(dataset_name_or_id))\n    plans_file = join(preprocessed_dataset_folder_base, plans_identifier + '.json')\n    plans = load_json(plans_file)\n    plans[\"continue_training\"] = continue_training\n    dataset_json = load_json(join(preprocessed_dataset_folder_base, 'dataset.json'))\n    nnunet_trainer = nnunet_trainer(plans=plans, configuration=configuration, fold=fold,\n                                    dataset_json=dataset_json, device=device)\n    return nnunet_trainer\n\n\ndef maybe_load_checkpoint(nnunet_trainer: nnUNetTrainer, continue_training: bool, validation_only: bool,\n                          pretrained_weights_file: str = None):\n    if continue_training and pretrained_weights_file is not None:\n        raise RuntimeError('Cannot both continue a training AND load pretrained weights. 
Pretrained weights can only '\n                           'be used at the beginning of the training.')\n    if continue_training:\n        expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_final.pth')\n        if not isfile(expected_checkpoint_file):\n            expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_latest.pth')\n        # special case where --c is used to run a previously aborted validation\n        if not isfile(expected_checkpoint_file):\n            expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_best.pth')\n        if not isfile(expected_checkpoint_file):\n            print(f\"WARNING: Cannot continue training because there seems to be no checkpoint available to \"\n                               f\"continue from. Starting a new training...\")\n            expected_checkpoint_file = None\n    elif validation_only:\n        expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_final.pth')\n        if not isfile(expected_checkpoint_file):\n            raise RuntimeError(f\"Cannot run validation because the training is not finished yet!\")\n    else:\n        if pretrained_weights_file is not None:\n            if not nnunet_trainer.was_initialized:\n                nnunet_trainer.initialize()\n            load_pretrained_weights(nnunet_trainer.network, pretrained_weights_file, verbose=True)\n        expected_checkpoint_file = None\n\n    if expected_checkpoint_file is not None:\n        nnunet_trainer.load_checkpoint(expected_checkpoint_file)\n\n\ndef setup_ddp(rank, world_size):\n    # initialize the process group\n    dist.init_process_group(\"nccl\", rank=rank, world_size=world_size)\n\n\ndef cleanup_ddp():\n    dist.destroy_process_group()\n\n\ndef run_ddp(rank, dataset_name_or_id, configuration, fold, tr, p, disable_checkpointing, c, val,\n            pretrained_weights, npz, val_with_best, world_size):\n    setup_ddp(rank, world_size)\n    
torch.cuda.set_device(torch.device('cuda', dist.get_rank()))\n\n    nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, tr, p, c)\n\n    if disable_checkpointing:\n        nnunet_trainer.disable_checkpointing = disable_checkpointing\n\n    assert not (c and val), f'Cannot set --c and --val flag at the same time. Dummy.'\n\n    maybe_load_checkpoint(nnunet_trainer, c, val, pretrained_weights)\n\n    if torch.cuda.is_available():\n        cudnn.deterministic = False\n        cudnn.benchmark = True\n\n    if not val:\n        nnunet_trainer.run_training()\n\n    if val_with_best:\n        nnunet_trainer.load_checkpoint(join(nnunet_trainer.output_folder, 'checkpoint_best.pth'))\n    nnunet_trainer.perform_actual_validation(npz)\n    cleanup_ddp()\n\n\ndef run_training(dataset_name_or_id: Union[str, int],\n                 configuration: str, fold: Union[int, str],\n                 trainer_class_name: str = 'nnUNetTrainer',\n                 plans_identifier: str = 'nnUNetPlans',\n                 pretrained_weights: Optional[str] = None,\n                 num_gpus: int = 1,\n                 export_validation_probabilities: bool = False,\n                 continue_training: bool = False,\n                 only_run_validation: bool = False,\n                 disable_checkpointing: bool = False,\n                 val_with_best: bool = False,\n                 device: torch.device = torch.device('cuda')):\n    if plans_identifier == 'nnUNetPlans':\n        print(\"\\n############################\\n\"\n              \"INFO: You are using the old nnU-Net default plans. We have updated our recommendations. \"\n              \"Please consider using those instead! 
\"\n              \"Read more here: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/resenc_presets.md\"\n              \"\\n############################\\n\")\n    if isinstance(fold, str):\n        if fold != 'all':\n            try:\n                fold = int(fold)\n            except ValueError as e:\n                print(f'Unable to convert given value for fold to int: {fold}. fold must bei either \"all\" or an integer!')\n                raise e\n\n    if val_with_best:\n        assert not disable_checkpointing, '--val_best is not compatible with --disable_checkpointing'\n\n    if num_gpus > 1:\n        assert device.type == 'cuda', f\"DDP training (triggered by num_gpus > 1) is only implemented for cuda devices. Your device: {device}\"\n\n        os.environ['MASTER_ADDR'] = 'localhost'\n        if 'MASTER_PORT' not in os.environ.keys():\n            port = str(find_free_network_port())\n            print(f\"using port {port}\")\n            os.environ['MASTER_PORT'] = port  # str(port)\n\n        mp.spawn(run_ddp,\n                 args=(\n                     dataset_name_or_id,\n                     configuration,\n                     fold,\n                     trainer_class_name,\n                     plans_identifier,\n                     disable_checkpointing,\n                     continue_training,\n                     only_run_validation,\n                     pretrained_weights,\n                     export_validation_probabilities,\n                     val_with_best,\n                     num_gpus),\n                 nprocs=num_gpus,\n                 join=True)\n    else:\n        nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, trainer_class_name,\n                                               plans_identifier, continue_training, device=device)\n\n        if disable_checkpointing:\n            nnunet_trainer.disable_checkpointing = disable_checkpointing\n\n        assert not (continue_training 
and only_run_validation), f'Cannot set --c and --val flag at the same time. Dummy.'\n\n        maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation, pretrained_weights)\n\n        if torch.cuda.is_available():\n            cudnn.deterministic = False\n            cudnn.benchmark = True\n\n        if not only_run_validation:\n            nnunet_trainer.run_training()\n\n        if val_with_best:\n            nnunet_trainer.load_checkpoint(join(nnunet_trainer.output_folder, 'checkpoint_best.pth'))\n        nnunet_trainer.perform_actual_validation(export_validation_probabilities)\n\n\ndef run_training_entry():\n    import argparse\n    parser = argparse.ArgumentParser()\n    parser.add_argument('dataset_name_or_id', type=str,\n                        help=\"Dataset name or ID to train with\")\n    parser.add_argument('configuration', type=str,\n                        help=\"Configuration that should be trained\")\n    parser.add_argument('fold', type=str,\n                        help='Fold of the 5-fold cross-validation. Should be an int between 0 and 4.')\n    parser.add_argument('-tr', type=str, required=False, default='nnUNetTrainer',\n                        help='[OPTIONAL] Use this flag to specify a custom trainer. Default: nnUNetTrainer')\n    parser.add_argument('-p', type=str, required=False, default='nnUNetPlans',\n                        help='[OPTIONAL] Use this flag to specify a custom plans identifier. Default: nnUNetPlans')\n    parser.add_argument('-pretrained_weights', type=str, required=False, default=None,\n                        help='[OPTIONAL] path to nnU-Net checkpoint file to be used as pretrained model. Will only '\n                             'be used when actually training. Beta. 
Use with caution.')\n    parser.add_argument('-num_gpus', type=int, default=1, required=False,\n                        help='Specify the number of GPUs to use for training')\n    parser.add_argument('--npz', action='store_true', required=False,\n                        help='[OPTIONAL] Save softmax predictions from final validation as npz files (in addition to predicted '\n                             'segmentations). Needed for finding the best ensemble.')\n    parser.add_argument('--c', action='store_true', required=False,\n                        help='[OPTIONAL] Continue training from latest checkpoint')\n    parser.add_argument('--val', action='store_true', required=False,\n                        help='[OPTIONAL] Set this flag to only run the validation. Requires training to have finished.')\n    parser.add_argument('--val_best', action='store_true', required=False,\n                        help='[OPTIONAL] If set, the validation will be performed with the checkpoint_best instead '\n                             'of checkpoint_final. NOT COMPATIBLE with --disable_checkpointing! '\n                             'WARNING: This will use the same \\'validation\\' folder as the regular validation '\n                             'with no way of distinguishing the two!')\n    parser.add_argument('--disable_checkpointing', action='store_true', required=False,\n                        help='[OPTIONAL] Set this flag to disable checkpointing. Ideal for testing things out and '\n                             'you dont want to flood your hard drive with checkpoints.')\n    parser.add_argument('-device', type=str, default='cuda', required=False,\n                    help=\"Use this to set the device the training should run with. Available options are 'cuda' \"\n                         \"(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! \"\n                         \"Use CUDA_VISIBLE_DEVICES=X nnUNetv2_train [...] 
instead!\")\n    args = parser.parse_args()\n\n    assert args.device in ['cpu', 'cuda', 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.'\n    if args.device == 'cpu':\n        # let's allow torch to use hella threads\n        torch.set_num_threads(multiprocessing.cpu_count())\n        device = torch.device('cpu')\n    elif args.device == 'cuda':\n        # multithreading in torch doesn't help nnU-Net if run on GPU\n        torch.set_num_threads(1)\n        torch.set_num_interop_threads(1)\n        device = torch.device('cuda')\n    else:\n        device = torch.device('mps')\n\n    run_training(args.dataset_name_or_id, args.configuration, args.fold, args.tr, args.p, args.pretrained_weights,\n                 args.num_gpus, args.npz, args.c, args.val, args.disable_checkpointing, args.val_best,\n                 device=device)\n\n\nif __name__ == '__main__':\n    os.environ['OMP_NUM_THREADS'] = '1'\n    os.environ['MKL_NUM_THREADS'] = '1'\n    os.environ['OPENBLAS_NUM_THREADS'] = '1'\n    # reduces the number of threads used for compiling. More threads don't help and can cause problems\n    os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'\n    # multiprocessing.set_start_method(\"spawn\")\n    run_training_entry()\n"
  },
  {
    "path": "nnunetv2/tests/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/tests/integration_tests/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/tests/integration_tests/add_lowres_and_cascade.py",
    "content": "from copy import deepcopy\n\nfrom batchgenerators.utilities.file_and_folder_operations import *\n\nfrom nnunetv2.paths import nnUNet_preprocessed\nfrom nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name\n\nif __name__ == '__main__':\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('-d', nargs='+', type=int, help='List of dataset ids')\n    args = parser.parse_args()\n\n    for d in args.d:\n        dataset_name = maybe_convert_to_dataset_name(d)\n        plans = load_json(join(nnUNet_preprocessed, dataset_name, 'nnUNetPlans.json'))\n        plans['configurations']['3d_lowres'] = {\n            \"data_identifier\": \"nnUNetPlans_3d_lowres\",  # do not be a dumbo and forget this. I was a dumbo. And I paid dearly with ~10 min debugging time\n            'inherits_from': '3d_fullres',\n            \"patch_size\": [20, 28, 20],\n            \"median_image_size_in_voxels\": [18.0, 25.0, 18.0],\n            \"spacing\": [2.0, 2.0, 2.0],\n            \"architecture\": deepcopy(plans['configurations']['3d_fullres'][\"architecture\"]),\n            \"next_stage\": \"3d_cascade_fullres\"\n        }\n        plans['configurations']['3d_lowres']['architecture'][\"arch_kwargs\"]['n_conv_per_stage'] = [2, 2, 2]\n        plans['configurations']['3d_lowres']['architecture'][\"arch_kwargs\"]['n_conv_per_stage_decoder'] = [2, 2]\n        plans['configurations']['3d_lowres']['architecture'][\"arch_kwargs\"]['strides'] = [[1, 1, 1], [2, 2, 2], [2, 2, 2]]\n        plans['configurations']['3d_lowres']['architecture'][\"arch_kwargs\"]['kernel_sizes'] = [[3, 3, 3], [3, 3, 3], [3, 3, 3]]\n        plans['configurations']['3d_lowres']['architecture'][\"arch_kwargs\"]['n_stages'] = 3\n        plans['configurations']['3d_lowres']['architecture'][\"arch_kwargs\"]['features_per_stage'] = [\n            32,\n            64,\n            128\n        ]\n\n        plans['configurations']['3d_cascade_fullres'] = 
{\n            'inherits_from': '3d_fullres',\n            \"previous_stage\": \"3d_lowres\"\n        }\n        save_json(plans, join(nnUNet_preprocessed, dataset_name, 'nnUNetPlans.json'), sort_keys=False)"
  },
  {
    "path": "nnunetv2/tests/integration_tests/cleanup_integration_test.py",
    "content": "import shutil\n\nfrom batchgenerators.utilities.file_and_folder_operations import isdir, join\n\nfrom nnunetv2.paths import nnUNet_raw, nnUNet_results, nnUNet_preprocessed\n\nif __name__ == '__main__':\n    # deletes everything!\n    dataset_names = [\n        'Dataset996_IntegrationTest_Hippocampus_regions_ignore',\n        'Dataset997_IntegrationTest_Hippocampus_regions',\n        'Dataset998_IntegrationTest_Hippocampus_ignore',\n        'Dataset999_IntegrationTest_Hippocampus',\n    ]\n    for fld in [nnUNet_raw, nnUNet_preprocessed, nnUNet_results]:\n        for d in dataset_names:\n            if isdir(join(fld, d)):\n                shutil.rmtree(join(fld, d))\n\n"
  },
  {
    "path": "nnunetv2/tests/integration_tests/lsf_commands.sh",
    "content": "bsub -q gpu.legacy -gpu num=1:j_exclusive=yes:gmem=1G -L /bin/bash \". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test.sh 996\"\nbsub -q gpu.legacy -gpu num=1:j_exclusive=yes:gmem=1G -L /bin/bash \". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test.sh 997\"\nbsub -q gpu.legacy -gpu num=1:j_exclusive=yes:gmem=1G -L /bin/bash \". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test.sh 998\"\nbsub -q gpu.legacy -gpu num=1:j_exclusive=yes:gmem=1G -L /bin/bash \". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test.sh 999\"\n\n\nbsub -q gpu.legacy -gpu num=2:j_exclusive=yes:gmem=1G -L /bin/bash \". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh 996\"\nbsub -q gpu.legacy -gpu num=2:j_exclusive=yes:gmem=1G -L /bin/bash \". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh 997\"\nbsub -q gpu.legacy -gpu num=2:j_exclusive=yes:gmem=1G -L /bin/bash \". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh 998\"\nbsub -q gpu.legacy -gpu num=2:j_exclusive=yes:gmem=1G -L /bin/bash \". 
/home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh 999\"\n"
  },
  {
    "path": "nnunetv2/tests/integration_tests/prepare_integration_tests.sh",
    "content": "# assumes you are in the nnunet repo!\n\n# prepare raw datasets\npython nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset999_IntegrationTest_Hippocampus.py\npython nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset998_IntegrationTest_Hippocampus_ignore.py\npython nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset997_IntegrationTest_Hippocampus_regions.py\npython nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset996_IntegrationTest_Hippocampus_regions_ignore.py\n\n# now run experiment planning without preprocessing\nnnUNetv2_plan_and_preprocess -d 996 997 998 999 --no_pp\n\n# now add 3d lowres and cascade\npython nnunetv2/tests/integration_tests/add_lowres_and_cascade.py -d 996 997 998 999\n\n# now preprocess everything\nnnUNetv2_preprocess -d 996 997 998 999 -c 2d 3d_lowres 3d_fullres -np 8 8 8  # no need to preprocess cascade as its the same data as 3d_fullres\n\n# done"
  },
  {
    "path": "nnunetv2/tests/integration_tests/readme.md",
    "content": "# Preface\n\nI am just a mortal with many tasks and limited time. Aint nobody got time for unittests.\n\nHOWEVER, at least some integration tests should be performed testing nnU-Net from start to finish.\n\n# Introduction - What the heck is happening?\nThis test covers all possible labeling scenarios (standard labels, regions, ignore labels and regions with \nignore labels). It runs the entire nnU-Net pipeline from start to finish:\n\n- fingerprint extraction\n- experiment planning\n- preprocessing\n- train all 4 configurations (2d, 3d_lowres, 3d_fullres, 3d_cascade_fullres) as 5-fold CV\n- automatically find the best model or ensemble\n- determine the postprocessing used for this\n- predict some test set\n- apply postprocessing to the test set\n\nTo speed things up, we do the following:\n- pick Dataset004_Hippocampus because it is quadratisch praktisch gut. MNIST of medical image segmentation\n- by default this dataset does not have 3d_lowres or cascade. We just manually add them (cool new feature, eh?). See `add_lowres_and_cascade.py` to learn more! \n- we use nnUNetTrainer_5epochs for a short training\n\n# How to run it?\n\nSet your pwd to be the nnunet repo folder (the one where the `nnunetv2` folder and the `setup.py` are located!)\n\nNow generate the 4 dummy datasets (ids 996, 997, 998, 999) from dataset 4. This will crash if you don't have Dataset004!\n```commandline\nbash nnunetv2/tests/integration_tests/prepare_integration_tests.sh \n```\n\nNow you can run the integration test for each of the datasets:\n```commandline\nbash nnunetv2/tests/integration_tests/run_integration_test.sh DATSET_ID\n```\nuse DATSET_ID 996, 997, 998 and 999. You can run these independently on different GPUs/systems to speed things up. 
\nThis will take i dunno like 10-30 Minutes!?\n\nAlso run \n```commandline\nbash nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh DATSET_ID\n```\nto verify DDP is working (needs 2 GPUs!)\n\n# How to check if the test was successful?\nIf I was not as lazy as I am I would have programmed some automatism that checks if Dice scores etc are in an acceptable range.\nSo you need to do the following:\n1) check that none of your runs crashed (duh)\n2) for each run, navigate to `nnUNet_results/DATASET_NAME` and take a look at the `inference_information.json` file. \nDoes it make sense? If so: NICE!\n\nOnce the integration test is completed you can delete all the temporary files associated with it by running:\n\n```commandline\npython nnunetv2/tests/integration_tests/cleanup_integration_test.py\n```"
  },
  {
    "path": "nnunetv2/tests/integration_tests/run_integration_test.sh",
    "content": "\n\nnnUNetv2_train $1 3d_fullres 0 -tr nnUNetTrainer_5epochs --npz\nnnUNetv2_train $1 3d_fullres 1 -tr nnUNetTrainer_5epochs --npz\nnnUNetv2_train $1 3d_fullres 2 -tr nnUNetTrainer_5epochs --npz\nnnUNetv2_train $1 3d_fullres 3 -tr nnUNetTrainer_5epochs --npz\nnnUNetv2_train $1 3d_fullres 4 -tr nnUNetTrainer_5epochs --npz\n\nnnUNetv2_train $1 2d 0 -tr nnUNetTrainer_5epochs --npz\nnnUNetv2_train $1 2d 1 -tr nnUNetTrainer_5epochs --npz\nnnUNetv2_train $1 2d 2 -tr nnUNetTrainer_5epochs --npz\nnnUNetv2_train $1 2d 3 -tr nnUNetTrainer_5epochs --npz\nnnUNetv2_train $1 2d 4 -tr nnUNetTrainer_5epochs --npz\n\nnnUNetv2_train $1 3d_lowres 0 -tr nnUNetTrainer_5epochs --npz\nnnUNetv2_train $1 3d_lowres 1 -tr nnUNetTrainer_5epochs --npz\nnnUNetv2_train $1 3d_lowres 2 -tr nnUNetTrainer_5epochs --npz\nnnUNetv2_train $1 3d_lowres 3 -tr nnUNetTrainer_5epochs --npz\nnnUNetv2_train $1 3d_lowres 4 -tr nnUNetTrainer_5epochs --npz\n\nnnUNetv2_train $1 3d_cascade_fullres 0 -tr nnUNetTrainer_5epochs --npz\nnnUNetv2_train $1 3d_cascade_fullres 1 -tr nnUNetTrainer_5epochs --npz\nnnUNetv2_train $1 3d_cascade_fullres 2 -tr nnUNetTrainer_5epochs --npz\nnnUNetv2_train $1 3d_cascade_fullres 3 -tr nnUNetTrainer_5epochs --npz\nnnUNetv2_train $1 3d_cascade_fullres 4 -tr nnUNetTrainer_5epochs --npz\n\npython nnunetv2/tests/integration_tests/run_integration_test_bestconfig_inference.py -d $1"
  },
  {
    "path": "nnunetv2/tests/integration_tests/run_integration_test_bestconfig_inference.py",
    "content": "import argparse\n\nimport torch\nfrom batchgenerators.utilities.file_and_folder_operations import join, load_pickle\n\nfrom nnunetv2.ensembling.ensemble import ensemble_folders\nfrom nnunetv2.evaluation.find_best_configuration import find_best_configuration, \\\n    dumb_trainer_config_plans_to_trained_models_dict\nfrom nnunetv2.inference.predict_from_raw_data import nnUNetPredictor\nfrom nnunetv2.paths import nnUNet_raw, nnUNet_results\nfrom nnunetv2.postprocessing.remove_connected_components import apply_postprocessing_to_folder\nfrom nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name\nfrom nnunetv2.utilities.file_path_utilities import get_output_folder\n\n\nif __name__ == '__main__':\n    \"\"\"\n    Predicts the imagesTs folder with the best configuration and applies postprocessing\n    \"\"\"\n    torch.set_num_threads(1)\n    torch.set_num_interop_threads(1)\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('-d', type=int, help='dataset id')\n    args = parser.parse_args()\n    d = args.d\n\n    dataset_name = maybe_convert_to_dataset_name(d)\n    source_dir = join(nnUNet_raw, dataset_name, 'imagesTs')\n    target_dir_base = join(nnUNet_results, dataset_name)\n\n    models = dumb_trainer_config_plans_to_trained_models_dict(['nnUNetTrainer_5epochs'],\n                                                              ['2d',\n                                                               '3d_lowres',\n                                                               '3d_cascade_fullres',\n                                                               '3d_fullres'],\n                                                              ['nnUNetPlans'])\n    ret = find_best_configuration(d, models, allow_ensembling=True, num_processes=8, overwrite=True,\n                                  folds=(0, 1, 2, 3, 4), strict=True)\n\n    has_ensemble = len(ret['best_model_or_ensemble']['selected_model_or_models']) > 1\n\n  
  # we don't use all folds to speed stuff up\n    used_folds = (0, 3)\n    output_folders = []\n    for im in ret['best_model_or_ensemble']['selected_model_or_models']:\n        output_dir = join(target_dir_base, f\"pred_{im['configuration']}\")\n        model_folder = get_output_folder(d, im['trainer'], im['plans_identifier'], im['configuration'])\n        # note that if the best model is the ensemble of 3d_lowres and 3d cascade then 3d_lowres will be predicted\n        # twice (once standalone and once to generate the predictions for the cascade) because we don't reuse the\n        # prediction here. Proper way would be to check for that and\n        # then give the output of 3d_lowres inference to the folder_with_segs_from_prev_stage kwarg in\n        # predict_from_raw_data. Since we allow for\n        # dynamically setting 'previous_stage' in the plans I am too lazy to implement this here. This is just an\n        # integration test after all. Take a closer look at how this is handled in predict_from_raw_data\n        predictor = nnUNetPredictor(verbose=False, allow_tqdm=False)\n        predictor.initialize_from_trained_model_folder(model_folder, used_folds)\n        predictor.predict_from_files(source_dir, output_dir, has_ensemble, overwrite=True)\n        # predict_from_raw_data(list_of_lists_or_source_folder=source_dir, output_folder=output_dir,\n        #                       model_training_output_dir=model_folder, use_folds=used_folds,\n        #                       save_probabilities=has_ensemble, verbose=False, overwrite=True)\n        output_folders.append(output_dir)\n\n    # if we have an ensemble, we need to ensemble the results\n    if has_ensemble:\n        ensemble_folders(output_folders, join(target_dir_base, 'ensemble_predictions'), save_merged_probabilities=False)\n        folder_for_pp = join(target_dir_base, 'ensemble_predictions')\n    else:\n        folder_for_pp = output_folders[0]\n\n    # apply postprocessing\n    pp_fns, pp_fn_kwargs 
= load_pickle(ret['best_model_or_ensemble']['postprocessing_file'])\n    apply_postprocessing_to_folder(folder_for_pp, join(target_dir_base, 'ensemble_predictions_postprocessed'),\n                                   pp_fns,\n                                   pp_fn_kwargs, plans_file_or_dict=ret['best_model_or_ensemble']['some_plans_file'])\n"
  },
  {
    "path": "nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh",
    "content": "nnUNetv2_train $1 3d_fullres 0 -tr nnUNetTrainer_10epochs -num_gpus 2\n"
  },
  {
    "path": "nnunetv2/tests/integration_tests/run_nnunet_inference.py",
    "content": "import os\nimport shutil\nimport subprocess\nfrom pathlib import Path\n\nimport nibabel as nib\nimport numpy as np\n\n\ndef dice_score(y_true, y_pred):\n    intersect = np.sum(y_true * y_pred)\n    denominator = np.sum(y_true) + np.sum(y_pred)\n    f1 = (2 * intersect) / (denominator + 1e-6)\n    return f1\n\n\ndef run_tests_and_exit_on_failure():\n    \"\"\"\n    Runs inference of a simple nnU-Net for CT body segmentation on a small example CT image \n    and checks if the output is correct.\n    \"\"\"\n    # Set nnUNet_results env var\n    weights_dir = Path.home() / \"github_actions_nnunet\" / \"results\"\n    os.environ[\"nnUNet_results\"] = str(weights_dir)\n\n    # Copy example file\n    os.makedirs(\"nnunetv2/tests/github_actions_output\", exist_ok=True)\n    shutil.copy(\"nnunetv2/tests/example_data/example_ct_sm.nii.gz\", \"nnunetv2/tests/github_actions_output/example_ct_sm_0000.nii.gz\")\n\n    # Run nnunet\n    subprocess.call(f\"nnUNetv2_predict -i nnunetv2/tests/github_actions_output -o nnunetv2/tests/github_actions_output -d 300 -tr nnUNetTrainer -c 3d_fullres -f 0 -device cpu\", shell=True)\n\n    # Check if the nnunet segmentation is correct\n    img_gt = nib.load(f\"nnunetv2/tests/example_data/example_ct_sm_T300_output.nii.gz\").get_fdata()\n    img_pred = nib.load(f\"nnunetv2/tests/github_actions_output/example_ct_sm.nii.gz\").get_fdata()\n    dice = dice_score(img_gt, img_pred)\n    images_equal = dice > 0.99  # allow for a small difference in the segmentation, otherwise the test will fail often\n    assert images_equal, f\"The nnunet segmentation is not correct (dice: {dice:.5f}).\"\n\n    # Clean up\n    shutil.rmtree(\"nnunetv2/tests/github_actions_output\")\n    shutil.rmtree(Path.home() / \"github_actions_nnunet\")\n\n\nif __name__ == \"__main__\":\n    run_tests_and_exit_on_failure()"
  },
  {
    "path": "nnunetv2/training/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/training/data_augmentation/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/training/data_augmentation/compute_initial_patch_size.py",
    "content": "import numpy as np\n\n\ndef get_patch_size(final_patch_size, rot_x, rot_y, rot_z, scale_range):\n    if isinstance(rot_x, (tuple, list)):\n        rot_x = max(np.abs(rot_x))\n    if isinstance(rot_y, (tuple, list)):\n        rot_y = max(np.abs(rot_y))\n    if isinstance(rot_z, (tuple, list)):\n        rot_z = max(np.abs(rot_z))\n    rot_x = min(90 / 360 * 2. * np.pi, rot_x)\n    rot_y = min(90 / 360 * 2. * np.pi, rot_y)\n    rot_z = min(90 / 360 * 2. * np.pi, rot_z)\n    from batchgenerators.augmentations.utils import rotate_coords_3d, rotate_coords_2d\n    coords = np.array(final_patch_size)\n    final_shape = np.copy(coords)\n    if len(coords) == 3:\n        final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, rot_x, 0, 0)), final_shape)), 0)\n        final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, 0, rot_y, 0)), final_shape)), 0)\n        final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, 0, 0, rot_z)), final_shape)), 0)\n    elif len(coords) == 2:\n        final_shape = np.max(np.vstack((np.abs(rotate_coords_2d(coords, rot_x)), final_shape)), 0)\n    final_shape /= min(scale_range)\n    return final_shape.astype(int)\n"
  },
  {
    "path": "nnunetv2/training/data_augmentation/custom_transforms/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/training/data_augmentation/custom_transforms/cascade_transforms.py",
    "content": "from typing import Union, List, Tuple, Callable\n\nimport numpy as np\nfrom acvl_utils.morphology.morphology_helper import label_with_component_sizes\nfrom batchgenerators.transforms.abstract_transforms import AbstractTransform\nfrom skimage.morphology import ball\nfrom skimage.morphology.binary import binary_erosion, binary_dilation, binary_closing, binary_opening\n\n\nclass MoveSegAsOneHotToData(AbstractTransform):\n    def __init__(self, index_in_origin: int, all_labels: Union[Tuple[int, ...], List[int]],\n                 key_origin=\"seg\", key_target=\"data\", remove_from_origin=True):\n        \"\"\"\n        Takes data_dict[seg][:, index_in_origin], converts it to one hot encoding and appends it to\n        data_dict[key_target]. Optionally removes index_in_origin from data_dict[seg].\n        \"\"\"\n        self.remove_from_origin = remove_from_origin\n        self.all_labels = all_labels\n        self.key_target = key_target\n        self.key_origin = key_origin\n        self.index_in_origin = index_in_origin\n\n    def __call__(self, **data_dict):\n        seg = data_dict[self.key_origin][:, self.index_in_origin:self.index_in_origin+1]\n\n        seg_onehot = np.zeros((seg.shape[0], len(self.all_labels), *seg.shape[2:]),\n                              dtype=data_dict[self.key_target].dtype)\n        for i, l in enumerate(self.all_labels):\n            seg_onehot[:, i][seg[:, 0] == l] = 1\n\n        data_dict[self.key_target] = np.concatenate((data_dict[self.key_target], seg_onehot), 1)\n\n        if self.remove_from_origin:\n            remaining_channels = [i for i in range(data_dict[self.key_origin].shape[1]) if i != self.index_in_origin]\n            data_dict[self.key_origin] = data_dict[self.key_origin][:, remaining_channels]\n\n        return data_dict\n\n\nclass RemoveRandomConnectedComponentFromOneHotEncodingTransform(AbstractTransform):\n    def __init__(self, channel_idx: Union[int, List[int]], key: str = \"data\", 
p_per_sample: float = 0.2,\n                 fill_with_other_class_p: float = 0.25,\n                 dont_do_if_covers_more_than_x_percent: float = 0.25, p_per_label: float = 1):\n        \"\"\"\n        Randomly removes connected components in the specified channel_idx of data_dict[key]. Only considers components\n        smaller than dont_do_if_covers_more_than_X_percent of the sample. Also has the option of simulating\n        misclassification as another class (fill_with_other_class_p)\n        \"\"\"\n        self.p_per_label = p_per_label\n        self.dont_do_if_covers_more_than_x_percent = dont_do_if_covers_more_than_x_percent\n        self.fill_with_other_class_p = fill_with_other_class_p\n        self.p_per_sample = p_per_sample\n        self.key = key\n        if not isinstance(channel_idx, (list, tuple)):\n            channel_idx = [channel_idx]\n        self.channel_idx = channel_idx\n\n    def __call__(self, **data_dict):\n        data = data_dict.get(self.key)\n        for b in range(data.shape[0]):\n            if np.random.uniform() < self.p_per_sample:\n                for c in self.channel_idx:\n                    if np.random.uniform() < self.p_per_label:\n                        # print(np.unique(data[b, c])) ## should be [0, 1]\n                        workon = data[b, c].astype(bool)\n                        if not np.any(workon):\n                            continue\n                        num_voxels = np.prod(workon.shape, dtype=np.uint64)\n                        lab, component_sizes = label_with_component_sizes(workon.astype(bool))\n                        if len(component_sizes) > 0:\n                            valid_component_ids = [i for i, j in component_sizes.items() if j <\n                                                   num_voxels*self.dont_do_if_covers_more_than_x_percent]\n                            # print('RemoveRandomConnectedComponentFromOneHotEncodingTransform', c,\n                            # np.unique(data[b, 
c]), len(component_sizes), valid_component_ids,\n                            # len(valid_component_ids))\n                            if len(valid_component_ids) > 0:\n                                random_component = np.random.choice(valid_component_ids)\n                                data[b, c][lab == random_component] = 0\n                                if np.random.uniform() < self.fill_with_other_class_p:\n                                    other_ch = [i for i in self.channel_idx if i != c]\n                                    if len(other_ch) > 0:\n                                        other_class = np.random.choice(other_ch)\n                                        data[b, other_class][lab == random_component] = 1\n        data_dict[self.key] = data\n        return data_dict\n\n\nclass ApplyRandomBinaryOperatorTransform(AbstractTransform):\n    def __init__(self,\n                 channel_idx: Union[int, List[int], Tuple[int, ...]],\n                 p_per_sample: float = 0.3,\n                 any_of_these: Tuple[Callable] = (binary_dilation, binary_erosion, binary_closing, binary_opening),\n                 key: str = \"data\",\n                 strel_size: Tuple[int, int] = (1, 10),\n                 p_per_label: float = 1):\n        \"\"\"\n        Applies random binary operations (specified by any_of_these) with random ball size (radius is uniformly sampled\n        from interval strel_size) to specified channels. 
Expects the channel_idx to correspond to a one hot encoded\n        segmentation (see for example MoveSegAsOneHotToData)\n        \"\"\"\n        self.p_per_label = p_per_label\n        self.strel_size = strel_size\n        self.key = key\n        self.any_of_these = any_of_these\n        self.p_per_sample = p_per_sample\n\n        if not isinstance(channel_idx, (list, tuple)):\n            channel_idx = [channel_idx]\n        self.channel_idx = channel_idx\n\n    def __call__(self, **data_dict):\n        for b in range(data_dict[self.key].shape[0]):\n            if np.random.uniform() < self.p_per_sample:\n                # this needs to be applied in random order to the channels\n                np.random.shuffle(self.channel_idx)\n                for c in self.channel_idx:\n                    if np.random.uniform() < self.p_per_label:\n                        operation = np.random.choice(self.any_of_these)\n                        selem = ball(np.random.uniform(*self.strel_size))\n                        workon = data_dict[self.key][b, c].astype(bool)\n                        if not np.any(workon):\n                            continue\n                        # print(np.unique(workon))\n                        res = operation(workon, selem).astype(data_dict[self.key].dtype)\n                        # print('ApplyRandomBinaryOperatorTransform', c, operation, np.sum(workon), np.sum(res))\n                        data_dict[self.key][b, c] = res\n\n                        # if class was added, we need to remove it in ALL other channels to keep one hot encoding\n                        # properties\n                        other_ch = [i for i in self.channel_idx if i != c]\n                        if len(other_ch) > 0:\n                            was_added_mask = (res - workon) > 0\n                            for oc in other_ch:\n                                data_dict[self.key][b, oc][was_added_mask] = 0\n                            # if class was removed, 
leave it at background\n        return data_dict\n"
  },
  {
    "path": "nnunetv2/training/data_augmentation/custom_transforms/deep_supervision_donwsampling.py",
    "content": "from typing import Tuple, Union, List\n\nfrom batchgenerators.augmentations.utils import resize_segmentation\nfrom batchgenerators.transforms.abstract_transforms import AbstractTransform\nimport numpy as np\n\n\nclass DownsampleSegForDSTransform2(AbstractTransform):\n    '''\n    data_dict['output_key'] will be a list of segmentations scaled according to ds_scales\n    '''\n    def __init__(self, ds_scales: Union[List, Tuple],\n                 order: int = 0, input_key: str = \"seg\",\n                 output_key: str = \"seg\", axes: Tuple[int] = None):\n        \"\"\"\n        Downscales data_dict[input_key] according to ds_scales. Each entry in ds_scales specified one deep supervision\n        output and its resolution relative to the original data, for example 0.25 specifies 1/4 of the original shape.\n        ds_scales can also be a tuple of tuples, for example ((1, 1, 1), (0.5, 0.5, 0.5)) to specify the downsampling\n        for each axis independently\n        \"\"\"\n        self.axes = axes\n        self.output_key = output_key\n        self.input_key = input_key\n        self.order = order\n        self.ds_scales = ds_scales\n\n    def __call__(self, **data_dict):\n        if self.axes is None:\n            axes = list(range(2, data_dict[self.input_key].ndim))\n        else:\n            axes = self.axes\n\n        output = []\n        for s in self.ds_scales:\n            if not isinstance(s, (tuple, list)):\n                s = [s] * len(axes)\n            else:\n                assert len(s) == len(axes), f'If ds_scales is a tuple for each resolution (one downsampling factor ' \\\n                                            f'for each axis) then the number of entried in that tuple (here ' \\\n                                            f'{len(s)}) must be the same as the number of axes (here {len(axes)}).'\n\n            if all([i == 1 for i in s]):\n                output.append(data_dict[self.input_key])\n            else:\n          
      new_shape = np.array(data_dict[self.input_key].shape).astype(float)\n                for i, a in enumerate(axes):\n                    new_shape[a] *= s[i]\n                new_shape = np.round(new_shape).astype(int)\n                out_seg = np.zeros(new_shape, dtype=data_dict[self.input_key].dtype)\n                for b in range(data_dict[self.input_key].shape[0]):\n                    for c in range(data_dict[self.input_key].shape[1]):\n                        out_seg[b, c] = resize_segmentation(data_dict[self.input_key][b, c], new_shape[2:], self.order)\n                output.append(out_seg)\n        data_dict[self.output_key] = output\n        return data_dict\n"
  },
  {
    "path": "nnunetv2/training/data_augmentation/custom_transforms/masking.py",
    "content": "from typing import List\n\nfrom batchgenerators.transforms.abstract_transforms import AbstractTransform\n\n\nclass MaskTransform(AbstractTransform):\n    def __init__(self, apply_to_channels: List[int], mask_idx_in_seg: int = 0, set_outside_to: int = 0,\n                 data_key: str = \"data\", seg_key: str = \"seg\"):\n        \"\"\"\n        Sets everything outside the mask to 0. CAREFUL! outside is defined as < 0, not =0 (in the Mask)!!!\n        \"\"\"\n        self.apply_to_channels = apply_to_channels\n        self.seg_key = seg_key\n        self.data_key = data_key\n        self.set_outside_to = set_outside_to\n        self.mask_idx_in_seg = mask_idx_in_seg\n\n    def __call__(self, **data_dict):\n        mask = data_dict[self.seg_key][:, self.mask_idx_in_seg] < 0\n        for c in self.apply_to_channels:\n            data_dict[self.data_key][:, c][mask] = self.set_outside_to\n        return data_dict\n"
  },
  {
    "path": "nnunetv2/training/data_augmentation/custom_transforms/region_based_training.py",
    "content": "from typing import List, Tuple, Union\n\nfrom batchgenerators.transforms.abstract_transforms import AbstractTransform\nimport numpy as np\n\n\nclass ConvertSegmentationToRegionsTransform(AbstractTransform):\n    def __init__(self, regions: Union[List, Tuple],\n                 seg_key: str = \"seg\", output_key: str = \"seg\", seg_channel: int = 0):\n        \"\"\"\n        regions are tuple of tuples where each inner tuple holds the class indices that are merged into one region,\n        example:\n        regions= ((1, 2), (2, )) will result in 2 regions: one covering the region of labels 1&2 and the other just 2\n        :param regions:\n        :param seg_key:\n        :param output_key:\n        \"\"\"\n        self.seg_channel = seg_channel\n        self.output_key = output_key\n        self.seg_key = seg_key\n        self.regions = regions\n\n    def __call__(self, **data_dict):\n        seg = data_dict.get(self.seg_key)\n        if seg is not None:\n            b, c, *shape = seg.shape\n            region_output = np.zeros((b, len(self.regions), *shape), dtype=bool)\n            for region_id, region_labels in enumerate(self.regions):\n                region_output[:, region_id] |= np.isin(seg[:, self.seg_channel], region_labels)\n            data_dict[self.output_key] = region_output.astype(np.uint8, copy=False)\n        return data_dict\n\n"
  },
  {
    "path": "nnunetv2/training/data_augmentation/custom_transforms/transforms_for_dummy_2d.py",
    "content": "from typing import Tuple, Union, List\n\nfrom batchgenerators.transforms.abstract_transforms import AbstractTransform\n\n\nclass Convert3DTo2DTransform(AbstractTransform):\n    def __init__(self, apply_to_keys: Union[List[str], Tuple[str]] = ('data', 'seg')):\n        \"\"\"\n        Transforms a 5D array (b, c, x, y, z) to a 4D array (b, c * x, y, z) by overloading the color channel\n        \"\"\"\n        self.apply_to_keys = apply_to_keys\n\n    def __call__(self, **data_dict):\n        for k in self.apply_to_keys:\n            shp = data_dict[k].shape\n            assert len(shp) == 5, 'This transform only works on 3D data, so expects 5D tensor (b, c, x, y, z) as input.'\n            data_dict[k] = data_dict[k].reshape((shp[0], shp[1] * shp[2], shp[3], shp[4]))\n            shape_key = f'orig_shape_{k}'\n            assert shape_key not in data_dict.keys(), f'Convert3DTo2DTransform needs to store the original shape. ' \\\n                                                      f'It does that using the {shape_key} key. That key is ' \\\n                                                      f'already taken. Bummer.'\n            data_dict[shape_key] = shp\n        return data_dict\n\n\nclass Convert2DTo3DTransform(AbstractTransform):\n    def __init__(self, apply_to_keys: Union[List[str], Tuple[str]] = ('data', 'seg')):\n        \"\"\"\n        Reverts Convert3DTo2DTransform by transforming a 4D array (b, c * x, y, z) back to 5D  (b, c, x, y, z)\n        \"\"\"\n        self.apply_to_keys = apply_to_keys\n\n    def __call__(self, **data_dict):\n        for k in self.apply_to_keys:\n            shape_key = f'orig_shape_{k}'\n            assert shape_key in data_dict.keys(), f'Did not find key {shape_key} in data_dict. Shitty. 
' \\\n                                                  f'Convert2DTo3DTransform only works in tandem with ' \\\n                                                  f'Convert3DTo2DTransform and you probably forgot to add ' \\\n                                                  f'Convert3DTo2DTransform to your pipeline. (Convert3DTo2DTransform ' \\\n                                                  f'is where the missing key is generated)'\n            original_shape = data_dict[shape_key]\n            current_shape = data_dict[k].shape\n            data_dict[k] = data_dict[k].reshape((original_shape[0], original_shape[1], original_shape[2],\n                                                 current_shape[-2], current_shape[-1]))\n        return data_dict\n"
  },
  {
    "path": "nnunetv2/training/dataloading/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/training/dataloading/data_loader.py",
    "content": "import os\nimport warnings\nfrom typing import Union, Tuple, List\n\nimport numpy as np\nimport torch\nfrom batchgenerators.dataloading.data_loader import DataLoader\nfrom batchgenerators.utilities.file_and_folder_operations import join, load_json\nfrom threadpoolctl import threadpool_limits\n\nfrom nnunetv2.paths import nnUNet_preprocessed\nfrom nnunetv2.training.dataloading.nnunet_dataset import nnUNetBaseDataset\nfrom nnunetv2.training.dataloading.nnunet_dataset import nnUNetDatasetBlosc2\nfrom nnunetv2.utilities.label_handling.label_handling import LabelManager\nfrom nnunetv2.utilities.plans_handling.plans_handler import PlansManager\nfrom acvl_utils.cropping_and_padding.bounding_boxes import crop_and_pad_nd\n\n\nclass nnUNetDataLoader(DataLoader):\n    def __init__(self,\n                 data: nnUNetBaseDataset,\n                 batch_size: int,\n                 patch_size: Union[List[int], Tuple[int, ...], np.ndarray],\n                 final_patch_size: Union[List[int], Tuple[int, ...], np.ndarray],\n                 label_manager: LabelManager,\n                 oversample_foreground_percent: float = 0.0,\n                 sampling_probabilities: Union[List[int], Tuple[int, ...], np.ndarray] = None,\n                 pad_sides: Union[List[int], Tuple[int, ...]] = None,\n                 probabilistic_oversampling: bool = False,\n                 transforms=None):\n        \"\"\"\n        If we get a 2D patch size, make it pseudo 3D and remember to remove the singleton dimension before\n        returning the batch\n        \"\"\"\n        super().__init__(data, batch_size, 1, None, True,\n                         False, True, sampling_probabilities)\n\n        if len(patch_size) == 2:\n            final_patch_size = (1, *final_patch_size)\n            patch_size = (1, *patch_size)\n            self.patch_size_was_2d = True\n        else:\n            self.patch_size_was_2d = False\n\n        # this is used by DataLoader for sampling train 
cases!\n        self.indices = data.identifiers\n\n        self.oversample_foreground_percent = oversample_foreground_percent\n        self.final_patch_size = final_patch_size\n        self.patch_size = patch_size\n        # need_to_pad denotes by how much we need to pad the data so that if we sample a patch of size final_patch_size\n        # (which is what the network will get) these patches will also cover the border of the images\n        self.need_to_pad = (np.array(patch_size) - np.array(final_patch_size)).astype(int)\n        if pad_sides is not None:\n            if self.patch_size_was_2d:\n                pad_sides = (0, *pad_sides)\n            for d in range(len(self.need_to_pad)):\n                self.need_to_pad[d] += pad_sides[d]\n        self.num_channels = None\n        self.pad_sides = pad_sides\n        self.sampling_probabilities = sampling_probabilities\n        self.annotated_classes_key = tuple([-1] + label_manager.all_labels)\n        self.has_ignore = label_manager.has_ignore_label\n        self.get_do_oversample = self._oversample_last_XX_percent if not probabilistic_oversampling \\\n            else self._probabilistic_oversampling\n        self.transforms = transforms\n        self.data_shape, self.seg_shape = self.determine_shapes()\n\n    def _oversample_last_XX_percent(self, sample_idx: int) -> bool:\n        \"\"\"\n        determines whether sample sample_idx in a minibatch needs to be guaranteed foreground\n        \"\"\"\n        return not sample_idx < round(self.batch_size * (1 - self.oversample_foreground_percent))\n\n    def _probabilistic_oversampling(self, sample_idx: int) -> bool:\n        # print('YEAH BOIIIIII')\n        return np.random.uniform() < self.oversample_foreground_percent\n\n    def determine_shapes(self):\n        # load one case\n        data, seg, seg_prev, properties = self._data.load_case(self._data.identifiers[0])\n        num_color_channels = data.shape[0]\n\n        if self.patch_size_was_2d:\n         
   spatial_shape = self.final_patch_size[1:] if self.transforms is not None else self.patch_size[1:]\n        else:\n            spatial_shape = self.final_patch_size if self.transforms is not None else self.patch_size\n\n        data_shape = (self.batch_size, num_color_channels, *spatial_shape)\n        channels_seg = seg.shape[0]\n        if seg_prev is not None:\n            channels_seg += 1\n        seg_shape = (self.batch_size, channels_seg, *spatial_shape)\n        return data_shape, seg_shape\n\n    def get_bbox(self, data_shape: np.ndarray, force_fg: bool, class_locations: Union[dict, None],\n                 overwrite_class: Union[int, Tuple[int, ...]] = None, verbose: bool = False):\n        # in dataloader 2d we need to select the slice prior to this and also modify the class_locations to only have\n        # locations for the given slice\n        need_to_pad = self.need_to_pad.copy()\n        dim = len(data_shape)\n\n        for d in range(dim):\n            # if case_all_data.shape + need_to_pad is still < patch size we need to pad more! We pad on both sides\n            # always\n            if need_to_pad[d] + data_shape[d] < self.patch_size[d]:\n                need_to_pad[d] = self.patch_size[d] - data_shape[d]\n\n        # we can now choose the bbox from -need_to_pad // 2 to shape - patch_size + need_to_pad // 2. Here we\n        # define what the upper and lower bound can be to then sample from them with np.random.randint\n        lbs = [- need_to_pad[i] // 2 for i in range(dim)]\n        ubs = [data_shape[i] + need_to_pad[i] // 2 + need_to_pad[i] % 2 - self.patch_size[i] for i in range(dim)]\n\n        # if not force_fg then we can just sample the bbox randomly from lb and ub. 
Else we need to make sure we get\n        # at least one of the foreground classes in the patch\n        if not force_fg and not self.has_ignore:\n            bbox_lbs = [np.random.randint(lbs[i], ubs[i] + 1) for i in range(dim)]\n            # print('I want a random location')\n        else:\n            if not force_fg and self.has_ignore:\n                selected_class = self.annotated_classes_key\n                if len(class_locations[selected_class]) == 0:\n                    # no annotated pixels in this case. Not good. But we can hardly skip it here\n                    warnings.warn('Warning! No annotated pixels in image!')\n                    selected_class = None\n            elif force_fg:\n                assert class_locations is not None, 'if force_fg is set class_locations cannot be None'\n                if overwrite_class is not None:\n                    assert overwrite_class in class_locations.keys(), 'desired class (\"overwrite_class\") does not ' \\\n                                                                      'have class_locations (missing key)'\n                # this saves us a np.unique. Preprocessing already did that for all cases. Neat.\n                # class_locations keys can also be tuple\n                eligible_classes_or_regions = [i for i in class_locations.keys() if len(class_locations[i]) > 0]\n\n                # if we have annotated_classes_key locations and other classes are present, remove the annotated_classes_key from the list\n                # strange formulation needed to circumvent\n                # ValueError: The truth value of an array with more than one element is ambiguous. 
Use a.any() or a.all()\n                tmp = [i == self.annotated_classes_key if isinstance(i, tuple) else False for i in eligible_classes_or_regions]\n                if any(tmp):\n                    if len(eligible_classes_or_regions) > 1:\n                        eligible_classes_or_regions.pop(np.where(tmp)[0][0])\n\n                if len(eligible_classes_or_regions) == 0:\n                    # this only happens if some image does not contain foreground voxels at all\n                    selected_class = None\n                    if verbose:\n                        print('case does not contain any foreground classes')\n                else:\n                    # I hate myself. Future me aint gonna be happy to read this\n                    # 2022_11_25: had to read it today. Wasn't too bad\n                    selected_class = eligible_classes_or_regions[np.random.choice(len(eligible_classes_or_regions))] if \\\n                        (overwrite_class is None or (overwrite_class not in eligible_classes_or_regions)) else overwrite_class\n                # print(f'I want to have foreground, selected class: {selected_class}')\n            else:\n                raise RuntimeError('lol what!?')\n\n            if selected_class is not None:\n                voxels_of_that_class = class_locations[selected_class]\n                selected_voxel = voxels_of_that_class[np.random.choice(len(voxels_of_that_class))]\n                # selected voxel is center voxel. 
Subtract half the patch size to get lower bbox voxel.\n                # Make sure it is within the bounds of lb and ub\n                # i + 1 because we have first dimension 0!\n                bbox_lbs = [max(lbs[i], selected_voxel[i + 1] - self.patch_size[i] // 2) for i in range(dim)]\n            else:\n                # If the image does not contain any foreground classes, we fall back to random cropping\n                bbox_lbs = [np.random.randint(lbs[i], ubs[i] + 1) for i in range(dim)]\n\n        bbox_ubs = [bbox_lbs[i] + self.patch_size[i] for i in range(dim)]\n\n        return bbox_lbs, bbox_ubs\n\n    def generate_train_batch(self):\n        selected_keys = self.get_indices()\n        # preallocate output tensors in final patch size and write transformed samples directly\n        data_all = torch.empty(self.data_shape, dtype=torch.float32)\n        seg_all = None\n\n        with torch.no_grad():\n            with threadpool_limits(limits=1, user_api=None):\n                for j, i in enumerate(selected_keys):\n                    # oversampling foreground will improve stability of model training, especially if many patches are empty\n                    # (Lung for example)\n                    force_fg = self.get_do_oversample(j)\n\n                    data, seg, seg_prev, properties = self._data.load_case(i)\n\n                    # If we are doing the cascade then the segmentation from the previous stage will already have been loaded by\n                    # self._data.load_case(i) (see nnUNetDataset.load_case)\n                    shape = data.shape[1:]\n\n                    bbox_lbs, bbox_ubs = self.get_bbox(shape, force_fg, properties['class_locations'])\n                    bbox = [[i, j] for i, j in zip(bbox_lbs, bbox_ubs)]\n\n                    data_cropped = torch.from_numpy(crop_and_pad_nd(data, bbox, 0)).float()\n                    seg_cropped = torch.from_numpy(crop_and_pad_nd(seg, bbox, -1)).to(torch.int16)\n                    if 
seg_prev is not None:\n                        seg_prev_cropped = torch.from_numpy(crop_and_pad_nd(seg_prev, bbox, -1)).to(torch.int16)\n                        seg_cropped = torch.cat((seg_cropped, seg_prev_cropped[None]), dim=0)\n\n                    if self.patch_size_was_2d:\n                        data_cropped = data_cropped[:, 0]\n                        seg_cropped = seg_cropped[:, 0]\n\n                    if self.transforms is not None:\n                        transformed = self.transforms(**{'image': data_cropped, 'segmentation': seg_cropped})\n                        data_sample = transformed['image']\n                        seg_sample = transformed['segmentation']\n                    else:\n                        data_sample = data_cropped\n                        seg_sample = seg_cropped\n\n                    data_all[j] = data_sample\n\n                    if isinstance(seg_sample, list):\n                        if seg_all is None:\n                            seg_all = [torch.empty((self.batch_size, *s.shape), dtype=s.dtype) for s in seg_sample]\n                        for s_idx, s in enumerate(seg_sample):\n                            seg_all[s_idx][j] = s\n                    else:\n                        if seg_all is None:\n                            seg_all = torch.empty((self.batch_size, *seg_sample.shape), dtype=seg_sample.dtype)\n                        seg_all[j] = seg_sample\n\n        return {'data': data_all, 'target': seg_all, 'keys': selected_keys}\n\n\nif __name__ == '__main__':\n    folder = join(nnUNet_preprocessed, 'Dataset002_Heart', 'nnUNetPlans_3d_fullres')\n    ds = nnUNetDatasetBlosc2(folder)  # this should not load the properties!\n    pm = PlansManager(join(folder, os.pardir, 'nnUNetPlans.json'))\n    lm = pm.get_label_manager(load_json(join(folder, os.pardir, 'dataset.json')))\n    dl = nnUNetDataLoader(ds, 5, (16, 16, 16), (16, 16, 16), lm,\n                          0.33, None, None)\n    a = next(dl)\n"
  },
  {
    "path": "nnunetv2/training/dataloading/nnunet_dataset.py",
    "content": "import os\nimport warnings\nfrom abc import ABC, abstractmethod\nfrom copy import deepcopy\nfrom functools import lru_cache\nfrom typing import List, Union, Type, Tuple\n\nimport numpy as np\nimport blosc2\nimport shutil\nfrom blosc2 import Filter, Codec\n\nfrom batchgenerators.utilities.file_and_folder_operations import join, load_pickle, isfile, write_pickle, subfiles\nfrom nnunetv2.configuration import default_num_processes\nfrom nnunetv2.training.dataloading.utils import unpack_dataset\nimport math\n\n\nclass nnUNetBaseDataset(ABC):\n    \"\"\"\n    Defines the interface\n    \"\"\"\n    def __init__(self, folder: str, identifiers: List[str] = None,\n                 folder_with_segs_from_previous_stage: str = None):\n        super().__init__()\n        # print('loading dataset')\n        if identifiers is None:\n            identifiers = self.get_identifiers(folder)\n        identifiers.sort()\n\n        self.source_folder = folder\n        self.folder_with_segs_from_previous_stage = folder_with_segs_from_previous_stage\n        self.identifiers = identifiers\n\n    def __getitem__(self, identifier):\n        return self.load_case(identifier)\n\n    @abstractmethod\n    def load_case(self, identifier):\n        pass\n\n    @staticmethod\n    @abstractmethod\n    def save_case(\n            data: np.ndarray,\n            seg: np.ndarray,\n            properties: dict,\n            output_filename_truncated: str\n            ):\n        pass\n\n    @staticmethod\n    @abstractmethod\n    def get_identifiers(folder: str) -> List[str]:\n        pass\n\n    @staticmethod\n    def unpack_dataset(folder: str, overwrite_existing: bool = False,\n                       num_processes: int = default_num_processes,\n                       verify: bool = True):\n        pass\n\n\nclass nnUNetDatasetNumpy(nnUNetBaseDataset):\n    def load_case(self, identifier):\n        data_npy_file = join(self.source_folder, identifier + '.npy')\n        if not 
isfile(data_npy_file):\n            data = np.load(join(self.source_folder, identifier + '.npz'))['data']\n        else:\n            data = np.load(data_npy_file, mmap_mode='r')\n\n        seg_npy_file = join(self.source_folder, identifier + '_seg.npy')\n        if not isfile(seg_npy_file):\n            seg = np.load(join(self.source_folder, identifier + '.npz'))['seg']\n        else:\n            seg = np.load(seg_npy_file, mmap_mode='r')\n\n        if self.folder_with_segs_from_previous_stage is not None:\n            prev_seg_npy_file = join(self.folder_with_segs_from_previous_stage, identifier + '.npy')\n            if isfile(prev_seg_npy_file):\n                seg_prev = np.load(prev_seg_npy_file, 'r')\n            else:\n                seg_prev = np.load(join(self.folder_with_segs_from_previous_stage, identifier + '.npz'))['seg']\n        else:\n            seg_prev = None\n\n        properties = load_pickle(join(self.source_folder, identifier + '.pkl'))\n        return data, seg, seg_prev, properties\n\n    @staticmethod\n    def save_case(\n            data: np.ndarray,\n            seg: np.ndarray,\n            properties: dict,\n            output_filename_truncated: str\n    ):\n        np.savez_compressed(output_filename_truncated + '.npz', data=data, seg=seg)\n        write_pickle(properties, output_filename_truncated + '.pkl')\n\n    @staticmethod\n    def save_seg(\n            seg: np.ndarray,\n            output_filename_truncated: str\n    ):\n        np.savez_compressed(output_filename_truncated + '.npz', seg=seg)\n\n    @staticmethod\n    def get_identifiers(folder: str) -> List[str]:\n        \"\"\"\n        returns all identifiers in the preprocessed data folder\n        \"\"\"\n        case_identifiers = [i[:-4] for i in os.listdir(folder) if i.endswith(\"npz\")]\n        return case_identifiers\n\n    @staticmethod\n    def unpack_dataset(folder: str, overwrite_existing: bool = False,\n                       num_processes: int = 
default_num_processes,\n                       verify: bool = True):\n        return unpack_dataset(folder, True, overwrite_existing, num_processes, verify)\n\n\nclass nnUNetDatasetBlosc2(nnUNetBaseDataset):\n    def __init__(self, folder: str, identifiers: List[str] = None,\n                 folder_with_segs_from_previous_stage: str = None):\n        super().__init__(folder, identifiers, folder_with_segs_from_previous_stage)\n        blosc2.set_nthreads(1)\n\n    def __getitem__(self, identifier):\n        return self.load_case(identifier)\n\n    def load_case(self, identifier):\n        dparams = {\n            'nthreads': 1\n        }\n        data_b2nd_file = join(self.source_folder, identifier + '.b2nd')\n\n        # mmap does not work with Windows -> https://github.com/MIC-DKFZ/nnUNet/issues/2723\n        mmap_kwargs = {} if os.name == \"nt\" else {'mmap_mode': 'r'}\n        data = blosc2.open(urlpath=data_b2nd_file, mode='r', dparams=dparams, **mmap_kwargs)\n\n        seg_b2nd_file = join(self.source_folder, identifier + '_seg.b2nd')\n        seg = blosc2.open(urlpath=seg_b2nd_file, mode='r', dparams=dparams, **mmap_kwargs)\n\n        if self.folder_with_segs_from_previous_stage is not None:\n            prev_seg_b2nd_file = join(self.folder_with_segs_from_previous_stage, identifier + '.b2nd')\n            seg_prev = blosc2.open(urlpath=prev_seg_b2nd_file, mode='r', dparams=dparams, **mmap_kwargs)\n        else:\n            seg_prev = None\n\n        properties = load_pickle(join(self.source_folder, identifier + '.pkl'))\n        return data, seg, seg_prev, properties\n\n    @staticmethod\n    def save_case(\n            data: np.ndarray,\n            seg: np.ndarray,\n            properties: dict,\n            output_filename_truncated: str,\n            chunks=None,\n            blocks=None,\n            chunks_seg=None,\n            blocks_seg=None,\n            clevel: int = 8,\n            codec=blosc2.Codec.ZSTD\n    ):\n        
blosc2.set_nthreads(1)\n        if chunks_seg is None:\n            chunks_seg = chunks\n        if blocks_seg is None:\n            blocks_seg = blocks\n\n        cparams = {\n            'codec': codec,\n            # 'filters': [blosc2.Filter.SHUFFLE],\n            # 'splitmode': blosc2.SplitMode.ALWAYS_SPLIT,\n            'clevel': clevel,\n        }\n        # print(output_filename_truncated, data.shape, seg.shape, blocks, chunks, blocks_seg, chunks_seg, data.dtype, seg.dtype)\n        blosc2.asarray(np.ascontiguousarray(data), urlpath=output_filename_truncated + '.b2nd', chunks=chunks,\n                       blocks=blocks, cparams=cparams)\n        blosc2.asarray(np.ascontiguousarray(seg), urlpath=output_filename_truncated + '_seg.b2nd', chunks=chunks_seg,\n                       blocks=blocks_seg, cparams=cparams)\n        write_pickle(properties, output_filename_truncated + '.pkl')\n\n    @staticmethod\n    def save_seg(\n            seg: np.ndarray,\n            output_filename_truncated: str,\n            chunks_seg=None,\n            blocks_seg=None\n    ):\n        blosc2.asarray(seg, urlpath=output_filename_truncated + '.b2nd', chunks=chunks_seg, blocks=blocks_seg)\n\n    @staticmethod\n    def get_identifiers(folder: str) -> List[str]:\n        \"\"\"\n        returns all identifiers in the preprocessed data folder\n        \"\"\"\n        case_identifiers = [i[:-5] for i in os.listdir(folder) if i.endswith(\".b2nd\") and not i.endswith(\"_seg.b2nd\")]\n        return case_identifiers\n\n    @staticmethod\n    def unpack_dataset(folder: str, overwrite_existing: bool = False,\n                       num_processes: int = default_num_processes,\n                       verify: bool = True):\n        pass\n\n    @staticmethod\n    def comp_blosc2_params(\n            image_size: Tuple[int, int, int, int],\n            patch_size: Union[Tuple[int, int], Tuple[int, int, int]],\n            bytes_per_pixel: int = 4,  # 4 byte are float32\n            
l1_cache_size_per_core_in_bytes=32768,  # 1 Kibibyte (KiB) = 2^10 Byte;  32 KiB = 32768 Byte\n            l3_cache_size_per_core_in_bytes=1441792,\n            # 1 Mebibyte (MiB) = 2^20 Byte = 1048576 Byte; 1.375 MiB = 1441792 Byte\n            safety_factor: float = 0.8  # we don't fill the caches to the brim. 0.8 means we target 80% of the caches\n    ):\n        \"\"\"\n        Computes a recommended block and chunk size for saving arrays with blosc v2.\n\n        Blosc2 NDim docs: \"Remember that having a second partition means that we have better flexibility to fit the\n        different partitions at the different CPU cache levels; typically the first partition (aka chunks) should\n        be made to fit in L3 cache, whereas the second partition (aka blocks) should rather fit in L2/L1 caches\n        (depending on whether compression ratio or speed is desired).\"\n        (https://www.blosc.org/posts/blosc2-ndim-intro/)\n        -> We are not 100% sure how to optimize for that. For now we try to fit the uncompressed block in L1. This\n        might spill over into L2, which is fine in our books.\n\n        Note: this is optimized for nnU-Net dataloading where each read operation is done by one core. We cannot use threading\n\n        Cache default values computed based on old Intel 4110 CPU with 32K L1, 128K L2 and 1408K L3 cache per core.\n        We cannot optimize further for more modern CPUs with more cache as the data will need to be read by the\n        old ones as well.\n\n        Args:\n            image_size: Image size, must be 4D (c, x, y, z). For 2D images, make x=1\n            patch_size: Patch size, spatial dimensions only. So (x, y) or (x, y, z)\n            bytes_per_pixel: Number of bytes per element. Example: float32 -> 4 bytes\n            l1_cache_size_per_core_in_bytes: The size of the L1 cache per core in Bytes.\n            l3_cache_size_per_core_in_bytes: The size of the L3 cache exclusively accessible by each core. 
Usually the global size of the L3 cache divided by the number of cores.\n            safety_factor: Fraction of each cache we aim to fill. 0.8 means we target 80% of the caches.\n\n        Returns:\n            The recommended block and the chunk size.\n        \"\"\"\n        # Fabians code is ugly, but eh\n\n        num_channels = image_size[0]\n        if len(patch_size) == 2:\n            patch_size = [1, *patch_size]\n        patch_size = np.array(patch_size)\n        block_size = np.array((num_channels, *[2 ** (max(0, math.ceil(math.log2(i)))) for i in patch_size]))\n\n        # shrink the block size until it fits in L1\n        estimated_nbytes_block = np.prod(block_size) * bytes_per_pixel\n        while estimated_nbytes_block > (l1_cache_size_per_core_in_bytes * safety_factor):\n            # pick largest deviation from patch_size that is not 1\n            axis_order = np.argsort(block_size[1:] / patch_size)[::-1]\n            idx = 0\n            picked_axis = axis_order[idx]\n            # skip axes that cannot be reduced any further\n            while block_size[picked_axis + 1] == 1:\n                idx += 1\n                picked_axis = axis_order[idx]\n            # now reduce that axis to the next lowest power of 2\n            block_size[picked_axis + 1] = 2 ** (max(0, math.floor(math.log2(block_size[picked_axis + 1] - 1))))\n            block_size[picked_axis + 1] = min(block_size[picked_axis + 1], image_size[picked_axis + 1])\n            estimated_nbytes_block = np.prod(block_size) * bytes_per_pixel\n\n        block_size = np.array([min(i, j) for i, j in zip(image_size, block_size)])\n\n        # note: there is no use extending the chunk size to 3d when we have a 2d patch size! 
This would unnecessarily\n        # load data into L3\n        # now tile the blocks into chunks until we hit image_size or the l3 cache per core limit\n        chunk_size = deepcopy(block_size)\n        estimated_nbytes_chunk = np.prod(chunk_size) * bytes_per_pixel\n        while estimated_nbytes_chunk < (l3_cache_size_per_core_in_bytes * safety_factor):\n            if patch_size[0] == 1 and all([i == j for i, j in zip(chunk_size[2:], image_size[2:])]):\n                break\n            if all([i == j for i, j in zip(chunk_size, image_size)]):\n                break\n            # find axis that deviates from block_size the most\n            axis_order = np.argsort(chunk_size[1:] / block_size[1:])\n            idx = 0\n            picked_axis = axis_order[idx]\n            while chunk_size[picked_axis + 1] == image_size[picked_axis + 1] or patch_size[picked_axis] == 1:\n                idx += 1\n                picked_axis = axis_order[idx]\n            chunk_size[picked_axis + 1] += block_size[picked_axis + 1]\n            chunk_size[picked_axis + 1] = min(chunk_size[picked_axis + 1], image_size[picked_axis + 1])\n            estimated_nbytes_chunk = np.prod(chunk_size) * bytes_per_pixel\n            if np.mean([i / j for i, j in zip(chunk_size[1:], patch_size)]) > 1.5:\n                # chunk size should not exceed patch size * 1.5 on average\n                chunk_size[picked_axis + 1] -= block_size[picked_axis + 1]\n                break\n        # better safe than sorry\n        chunk_size = [min(i, j) for i, j in zip(image_size, chunk_size)]\n\n        # print(image_size, chunk_size, block_size)\n        return tuple(block_size), tuple(chunk_size)\n\n\nfile_ending_dataset_mapping = {\n    'npz': nnUNetDatasetNumpy,\n    'b2nd': nnUNetDatasetBlosc2\n}\n\n\ndef infer_dataset_class(folder: str) -> Union[Type[nnUNetDatasetBlosc2], Type[nnUNetDatasetNumpy]]:\n    file_endings = set([os.path.basename(i).split('.')[-1] for i in subfiles(folder, join=False)])\n  
  if 'pkl' in file_endings:\n        file_endings.remove('pkl')\n    if 'npy' in file_endings:\n        file_endings.remove('npy')\n    assert len(file_endings) == 1, (f'Found more than one file ending in the folder {folder}. '\n                                    f'Unable to infer nnUNetDataset variant!')\n    return file_ending_dataset_mapping[list(file_endings)[0]]\n"
  },
  {
    "path": "nnunetv2/training/dataloading/utils.py",
    "content": "from __future__ import annotations\nimport multiprocessing\nimport os\nfrom typing import List\nfrom pathlib import Path\nfrom warnings import warn\n\nimport numpy as np\nfrom batchgenerators.utilities.file_and_folder_operations import isfile, subfiles\nfrom nnunetv2.configuration import default_num_processes\n\n\ndef _convert_to_npy(npz_file: str, unpack_segmentation: bool = True, overwrite_existing: bool = False,\n                    verify_npy: bool = False, fail_ctr: int = 0) -> None:\n    data_npy = npz_file[:-3] + \"npy\"\n    seg_npy = npz_file[:-4] + \"_seg.npy\"\n    try:\n        npz_content = None  # will only be opened on demand\n\n        if overwrite_existing or not isfile(data_npy):\n            try:\n                npz_content = np.load(npz_file) if npz_content is None else npz_content\n            except Exception as e:\n                print(f\"Unable to open preprocessed file {npz_file}. Rerun nnUNetv2_preprocess!\")\n                raise e\n            np.save(data_npy, npz_content['data'])\n\n        if unpack_segmentation and (overwrite_existing or not isfile(seg_npy)):\n            try:\n                npz_content = np.load(npz_file) if npz_content is None else npz_content\n            except Exception as e:\n                print(f\"Unable to open preprocessed file {npz_file}. 
Rerun nnUNetv2_preprocess!\")\n                raise e\n            np.save(npz_file[:-4] + \"_seg.npy\", npz_content['seg'])\n\n        if verify_npy:\n            try:\n                np.load(data_npy, mmap_mode='r')\n                if isfile(seg_npy):\n                    np.load(seg_npy, mmap_mode='r')\n            except ValueError:\n                os.remove(data_npy)\n                # the segmentation may not have been unpacked, so only remove it if it exists\n                if isfile(seg_npy):\n                    os.remove(seg_npy)\n                print(f\"Error when checking {data_npy} and {seg_npy}, fixing...\")\n                if fail_ctr < 2:\n                    _convert_to_npy(npz_file, unpack_segmentation, overwrite_existing, verify_npy, fail_ctr+1)\n                else:\n                    raise RuntimeError(\"Unable to fix unpacking. Please check your system or rerun nnUNetv2_preprocess\")\n\n    except KeyboardInterrupt:\n        if isfile(data_npy):\n            os.remove(data_npy)\n        if isfile(seg_npy):\n            os.remove(seg_npy)\n        raise KeyboardInterrupt\n\n\ndef unpack_dataset(folder: str, unpack_segmentation: bool = True, overwrite_existing: bool = False,\n                   num_processes: int = default_num_processes,\n                   verify: bool = False):\n    \"\"\"\n    all npz files in this folder belong to the dataset, unpack them all\n    \"\"\"\n    with multiprocessing.get_context(\"spawn\").Pool(num_processes) as p:\n        npz_files = subfiles(folder, True, None, \".npz\", True)\n        p.starmap(_convert_to_npy, zip(npz_files,\n                                       [unpack_segmentation] * len(npz_files),\n                                       [overwrite_existing] * len(npz_files),\n                                       [verify] * len(npz_files))\n                  )\n\n\nif __name__ == '__main__':\n    unpack_dataset('/media/fabian/data/nnUNet_preprocessed/Dataset002_Heart/2d')"
  },
  {
    "path": "nnunetv2/training/logging/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/training/logging/nnunet_logger.py",
    "content": "import matplotlib\nfrom batchgenerators.utilities.file_and_folder_operations import join\n\nmatplotlib.use('agg')\nimport seaborn as sns\nimport matplotlib.pyplot as plt\nfrom typing import Any\nfrom pathlib import Path\nimport shutil\nimport os\n\ntry:\n    import wandb\nexcept ImportError:\n    wandb = None\n\n\ndef get_cluster_job_id():\n    job_id = None\n    if \"LSB_JOBID\" in os.environ:\n        job_id = os.environ[\"LSB_JOBID\"]\n    if \"SLURM_JOB_ID\" in os.environ:\n        job_id = os.environ[\"SLURM_JOB_ID\"]\n    return job_id\n\n\nclass MetaLogger(object):\n    \"\"\"A meta logger that bundles multiple loggers behind a single interface.\n\n    The default configuration includes a local logger used for reading values,\n    plotting progress, and checkpointing.\n    \"\"\"\n\n    def __init__(self, output_folder, resume, verbose: bool = False):\n        \"\"\"Initialize the meta logger.\n\n        Args:\n            output_folder: The output folder.\n            resume: Whether to resume training if possible.\n            verbose: Whether to enable verbose logging in the local logger.\n        \"\"\"\n        self.output_folder = output_folder\n        self.resume = resume\n        self.loggers = []\n        self.local_logger = LocalLogger(verbose)\n        if self._is_logger_enabled(\"nnUNet_wandb_enabled\"):\n            self.loggers.append(WandbLogger(output_folder, resume))\n\n    def update_config(self, config: dict):\n        \"\"\"Add a new or update an existing experiment configuration to the logger.\n\n        Args:\n            config: Logger configuration options.\n        \"\"\"\n        for logger in self.loggers:\n            logger.update_config(config)\n\n    def log(self, key: str, value: Any, step: int):\n        \"\"\"Log a value for a given step.\n\n        Args:\n            key: Metric or field name.\n            value: Value to log.\n            step: Step index (typically epoch).\n        \"\"\"\n        
self.local_logger.log(key, value, step)\n        if isinstance(value, list):\n            for i, val in enumerate(value):\n                for logger in self.loggers:\n                    logger.log(f\"{key}/class_{i+1}\", val, step)\n        else:\n            for logger in self.loggers:\n                logger.log(key, value, step)\n        \n        # handle the ema_fg_dice special case! It is automatically logged when we add a new mean_fg_dice\n        if key == 'mean_fg_dice':\n            new_ema_pseudo_dice = self.get_value('ema_fg_dice', step=step-1) * 0.9 + 0.1 * value \\\n                if len(self.get_value('ema_fg_dice', step=None)) > 0 else value\n            self.log('ema_fg_dice', new_ema_pseudo_dice, step)\n\n    def log_summary(self, key: str, value: Any):\n        \"\"\"Log a summary value. These are usually values that are not logged every step but only once. \n        This can be for example the final validation Dice.\n\n        Args:\n            key: Metric or field name.\n            value: Value to summarize.\n        \"\"\"\n        for logger in self.loggers:\n            logger.log_summary(key, value)\n\n    def get_value(self, key: str, step: Any):\n        \"\"\"Fetch a logged value from the local logger.\n\n        Args:\n            key: Metric or field name.\n            step: Step index to retrieve, or None to return all values.\n\n        Returns:\n            The logged value or list of values from the local logger.\n        \"\"\"\n        return self.local_logger.get_value(key, step)\n\n    def plot_progress_png(self, output_folder: str):\n        \"\"\"Write a progress plot PNG using local logger data.\n\n        Args:\n            output_folder: Directory where the plot image is saved.\n        \"\"\"\n        self.local_logger.plot_progress_png(output_folder)\n\n    def get_checkpoint(self):\n        \"\"\"Return the local logger checkpoint data.\n\n        Returns:\n            The checkpoint payload used to restore logging 
state.\n        \"\"\"\n        return self.local_logger.get_checkpoint()\n\n    def load_checkpoint(self, checkpoint: dict):\n        \"\"\"Restore the local logger from a checkpoint payload.\n\n        Args:\n            checkpoint: Checkpoint data returned by `get_checkpoint`.\n        \"\"\"\n        self.local_logger.load_checkpoint(checkpoint)\n\n    def _is_logger_enabled(self, env_var):\n        env_var_result = str(os.getenv(env_var, \"0\"))\n        if env_var_result in (\"0\", \"False\", \"false\"):\n            return False\n        elif env_var_result in (\"1\", \"True\", \"true\"):\n            return True\n        else:\n            raise RuntimeError(\"nnU-Net logger environment variable has the wrong value. Must be '0' (disabled) or '1' (enabled).\")\n\n\nclass LocalLogger:\n    \"\"\"\n    This class is really trivial. Don't expect cool functionality here. This is my makeshift solution to problems\n    arising from out-of-sync epoch numbers and numbers of logged loss values. It also simplifies the trainer class a\n    little\n\n    YOU MUST LOG EXACTLY ONE VALUE PER EPOCH FOR EACH OF THE LOGGING ITEMS! DONT FUCK IT UP\n    \"\"\"\n    def __init__(self, verbose: bool = False):\n        self.my_fantastic_logging = {\n            'mean_fg_dice': list(),\n            'ema_fg_dice': list(),\n            'dice_per_class_or_region': list(),\n            'train_losses': list(),\n            'val_losses': list(),\n            'lrs': list(),\n            'epoch_start_timestamps': list(),\n            'epoch_end_timestamps': list()\n        }\n        self.verbose = verbose\n        # shut up, this logging is great\n\n    def log(self, key, value, epoch: int):\n        \"\"\"\n        sometimes shit gets messed up. 
We try to catch that here\n        \"\"\"\n        assert key in self.my_fantastic_logging.keys() and isinstance(self.my_fantastic_logging[key], list), \\\n            'This function is only intended to log stuff to lists and to have one entry per epoch'\n\n        if self.verbose: print(f'logging {key}: {value} for epoch {epoch}')\n\n        if len(self.my_fantastic_logging[key]) < (epoch + 1):\n            self.my_fantastic_logging[key].append(value)\n        else:\n            assert len(self.my_fantastic_logging[key]) == (epoch + 1), 'something went horribly wrong. My logging ' \\\n                                                                       'lists length is off by more than 1'\n            print(f'maybe some logging issue!? logging {key} and {value}')\n            self.my_fantastic_logging[key][epoch] = value\n\n    def get_value(self, key, step):\n        if step is not None:\n            return self.my_fantastic_logging[key][step]\n        else:\n            return self.my_fantastic_logging[key]\n\n    def plot_progress_png(self, output_folder):\n        # we infer the epoch from our internal logging\n        epoch = min([len(i) for i in self.my_fantastic_logging.values()]) - 1  # lists of epoch 0 have len 1\n        sns.set(font_scale=2.5)\n        fig, ax_all = plt.subplots(3, 1, figsize=(30, 54))\n        # regular progress.png as we are used to from previous nnU-Net versions\n        ax = ax_all[0]\n        ax2 = ax.twinx()\n        x_values = list(range(epoch + 1))\n        ax.plot(x_values, self.my_fantastic_logging['train_losses'][:epoch + 1], color='b', ls='-', label=\"loss_tr\", linewidth=4)\n        ax.plot(x_values, self.my_fantastic_logging['val_losses'][:epoch + 1], color='r', ls='-', label=\"loss_val\", linewidth=4)\n        ax2.plot(x_values, self.my_fantastic_logging['mean_fg_dice'][:epoch + 1], color='g', ls='dotted', label=\"pseudo dice\",\n                 linewidth=3)\n        ax2.plot(x_values, 
self.my_fantastic_logging['ema_fg_dice'][:epoch + 1], color='g', ls='-', label=\"pseudo dice (mov. avg.)\",\n                 linewidth=4)\n        ax.set_xlabel(\"epoch\")\n        ax.set_ylabel(\"loss\")\n        ax2.set_ylabel(\"pseudo dice\")\n        ax.legend(loc=(0, 1))\n        ax2.legend(loc=(0.2, 1))\n\n        # epoch times to see whether the training speed is consistent (inconsistent means there are other jobs\n        # clogging up the system)\n        ax = ax_all[1]\n        ax.plot(x_values, [i - j for i, j in zip(self.my_fantastic_logging['epoch_end_timestamps'][:epoch + 1],\n                                                 self.my_fantastic_logging['epoch_start_timestamps'])][:epoch + 1], color='b',\n                ls='-', label=\"epoch duration\", linewidth=4)\n        ylim = [0] + [ax.get_ylim()[1]]\n        ax.set(ylim=ylim)\n        ax.set_xlabel(\"epoch\")\n        ax.set_ylabel(\"time [s]\")\n        ax.legend(loc=(0, 1))\n\n        # learning rate\n        ax = ax_all[2]\n        ax.plot(x_values, self.my_fantastic_logging['lrs'][:epoch + 1], color='b', ls='-', label=\"learning rate\", linewidth=4)\n        ax.set_xlabel(\"epoch\")\n        ax.set_ylabel(\"learning rate\")\n        ax.legend(loc=(0, 1))\n\n        plt.tight_layout()\n\n        fig.savefig(join(output_folder, \"progress.png\"))\n        plt.close()\n\n    def get_checkpoint(self):\n        return self.my_fantastic_logging\n\n    def load_checkpoint(self, checkpoint: dict):\n        self.my_fantastic_logging = checkpoint\n\n\nclass WandbLogger:\n    \"\"\"Weights & Biases logger for nnU-Net training runs.\n\n    Environment Variables:\n        nnUNet_wandb_enabled: Whether W&B logger is enabled (default: 0 -> Disabled)\n        nnUNet_wandb_project: W&B project name (default: \"nnunet\").\n        nnUNet_wandb_mode: W&B mode, e.g. 
\"online\" or \"offline\" (default: \"online\").\n    \"\"\"\n\n    def __init__(self, output_folder, resume):\n        \"\"\"Initialize a W&B run and handle resume behavior.\n\n        Args:\n            output_folder: Directory where W&B run data is stored.\n            resume: Whether to resume a previous W&B run if present.\n            verbose: Unused verbosity flag (kept for interface compatibility).\n        \"\"\"\n        if wandb is None:\n            raise RuntimeError(\"W&B is not installed. Please install W&B with 'pip install wandb' before using the WandbLogger.\")\n\n        self.output_folder = Path(output_folder)\n        self.resume = resume\n        self.project = os.getenv(\"nnUNet_wandb_project\", \"nnunet\")\n        self.mode = os.getenv(\"nnUNet_wandb_mode\", \"online\")\n\n        wandb_id = None\n        if (self.output_folder / \"wandb\").is_dir():\n            if self.resume:\n                wandb_dir = self.output_folder / \"wandb\" / \"latest-run\"\n                wandb_filename = next(filename for filename in wandb_dir.iterdir() if filename.suffix == \".wandb\")\n                wandb_id = wandb_filename.stem[4:]\n            else:\n                shutil.rmtree(str(self.output_folder / \"wandb\"))\n\n        _resume = \"allow\" if self.resume else \"never\"\n        self.run = wandb.init(project=self.project, dir=str(self.output_folder), id=wandb_id, mode=self.mode, resume=_resume)\n        self.run.config.update({\"JobID\": get_cluster_job_id()})\n        self.wandb_init_step = self.run.step\n    \n    def update_config(self, config: dict):\n        \"\"\"Update W&B config with training metadata.\n\n        Args:\n            config: Configuration values to merge into the run config.\n        \"\"\"\n        self.run.config.update(config)\n\n    def log(self, key, value, step: int):\n        \"\"\"Log a scalar value to W&B.\n\n        Args:\n            key: Metric or field name.\n            value: Value to log.\n            
step: Step index (typically epoch).\n        \"\"\"\n        self.log_summary(\"current_epoch\", step)\n        if self.resume and step < self.wandb_init_step:\n            return\n        self.run.log({key: value}, step=step)\n\n    def log_summary(self, key, value):\n        \"\"\"Write a summary value to W&B.\n\n        Args:\n            key: Metric or field name.\n            value: Summary value to store.\n        \"\"\"\n        self.run.summary[key] = value\n"
  },
  {
    "path": "nnunetv2/training/loss/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/training/loss/compound_losses.py",
    "content": "import torch\nfrom nnunetv2.training.loss.dice import SoftDiceLoss, MemoryEfficientSoftDiceLoss\nfrom nnunetv2.training.loss.robust_ce_loss import RobustCrossEntropyLoss, TopKLoss\nfrom nnunetv2.utilities.helpers import softmax_helper_dim1\nfrom torch import nn\n\n\nclass DC_and_CE_loss(nn.Module):\n    def __init__(self, soft_dice_kwargs, ce_kwargs, weight_ce=1, weight_dice=1, ignore_label=None,\n                 dice_class=SoftDiceLoss):\n        \"\"\"\n        Weights for CE and Dice do not need to sum to one. You can set whatever you want.\n        :param soft_dice_kwargs:\n        :param ce_kwargs:\n        :param aggregate:\n        :param square_dice:\n        :param weight_ce:\n        :param weight_dice:\n        \"\"\"\n        super(DC_and_CE_loss, self).__init__()\n        if ignore_label is not None:\n            ce_kwargs['ignore_index'] = ignore_label\n\n        self.weight_dice = weight_dice\n        self.weight_ce = weight_ce\n        self.ignore_label = ignore_label\n\n        self.ce = RobustCrossEntropyLoss(**ce_kwargs)\n        self.dc = dice_class(apply_nonlin=softmax_helper_dim1, **soft_dice_kwargs)\n\n    def forward(self, net_output: torch.Tensor, target: torch.Tensor):\n        \"\"\"\n        target must be b, c, x, y(, z) with c=1\n        :param net_output:\n        :param target:\n        :return:\n        \"\"\"\n        if self.ignore_label is not None:\n            assert target.shape[1] == 1, 'ignore label is not implemented for one hot encoded target variables ' \\\n                                         '(DC_and_CE_loss)'\n            mask = target != self.ignore_label\n            # remove ignore label from target, replace with one of the known labels. 
It doesn't matter because we\n            # ignore gradients in those areas anyway\n            target_dice = torch.where(mask, target, 0)\n            num_fg = mask.sum()\n        else:\n            target_dice = target\n            mask = None\n\n        dc_loss = self.dc(net_output, target_dice, loss_mask=mask) \\\n            if self.weight_dice != 0 else 0\n        ce_loss = self.ce(net_output, target[:, 0]) \\\n            if self.weight_ce != 0 and (self.ignore_label is None or num_fg > 0) else 0\n\n        result = self.weight_ce * ce_loss + self.weight_dice * dc_loss\n        return result\n\n\nclass DC_and_BCE_loss(nn.Module):\n    def __init__(self, bce_kwargs, soft_dice_kwargs, weight_ce=1, weight_dice=1, use_ignore_label: bool = False,\n                 dice_class=MemoryEfficientSoftDiceLoss):\n        \"\"\"\n        DO NOT APPLY NONLINEARITY IN YOUR NETWORK!\n\n        target must be one hot encoded\n        IMPORTANT: We assume the ignore label is located in target[:, -1]!!!\n\n        :param soft_dice_kwargs:\n        :param bce_kwargs:\n        :param aggregate:\n        \"\"\"\n        super(DC_and_BCE_loss, self).__init__()\n        if use_ignore_label:\n            bce_kwargs['reduction'] = 'none'\n\n        self.weight_dice = weight_dice\n        self.weight_ce = weight_ce\n        self.use_ignore_label = use_ignore_label\n\n        self.ce = nn.BCEWithLogitsLoss(**bce_kwargs)\n        self.dc = dice_class(apply_nonlin=torch.sigmoid, **soft_dice_kwargs)\n\n    def forward(self, net_output: torch.Tensor, target: torch.Tensor):\n        if self.use_ignore_label:\n            # target is one hot encoded here. invert it so that it is True wherever we can compute the loss\n            if target.dtype == torch.bool:\n                mask = ~target[:, -1:]\n            else:\n                mask = (1 - target[:, -1:]).bool()\n            # remove ignore channel now that we have the mask\n            # why did we use clone in the past? 
Should have documented that...\n            # target_regions = torch.clone(target[:, :-1])\n            target_regions = target[:, :-1]\n        else:\n            target_regions = target\n            mask = None\n\n        dc_loss = self.dc(net_output, target_regions, loss_mask=mask)\n        target_regions = target_regions.float()\n        if mask is not None:\n            ce_loss = (self.ce(net_output, target_regions) * mask).sum() / torch.clip(mask.sum(), min=1e-8)\n        else:\n            ce_loss = self.ce(net_output, target_regions)\n        result = self.weight_ce * ce_loss + self.weight_dice * dc_loss\n        return result\n\n\nclass DC_and_topk_loss(nn.Module):\n    def __init__(self, soft_dice_kwargs, ce_kwargs, weight_ce=1, weight_dice=1, ignore_label=None):\n        \"\"\"\n        Weights for CE and Dice do not need to sum to one. You can set whatever you want.\n        :param soft_dice_kwargs:\n        :param ce_kwargs:\n        :param aggregate:\n        :param square_dice:\n        :param weight_ce:\n        :param weight_dice:\n        \"\"\"\n        super().__init__()\n        if ignore_label is not None:\n            ce_kwargs['ignore_index'] = ignore_label\n\n        self.weight_dice = weight_dice\n        self.weight_ce = weight_ce\n        self.ignore_label = ignore_label\n\n        self.ce = TopKLoss(**ce_kwargs)\n        self.dc = SoftDiceLoss(apply_nonlin=softmax_helper_dim1, **soft_dice_kwargs)\n\n    def forward(self, net_output: torch.Tensor, target: torch.Tensor):\n        \"\"\"\n        target must be b, c, x, y(, z) with c=1\n        :param net_output:\n        :param target:\n        :return:\n        \"\"\"\n        if self.ignore_label is not None:\n            assert target.shape[1] == 1, 'ignore label is not implemented for one hot encoded target variables ' \\\n                                         '(DC_and_topk_loss)'\n            mask = (target != self.ignore_label).bool()\n            # remove ignore label from target, 
replace with one of the known labels. It doesn't matter because we\n            # ignore gradients in those areas anyway\n            target_dice = torch.clone(target)\n            target_dice[target == self.ignore_label] = 0\n            num_fg = mask.sum()\n        else:\n            target_dice = target\n            mask = None\n\n        dc_loss = self.dc(net_output, target_dice, loss_mask=mask) \\\n            if self.weight_dice != 0 else 0\n        ce_loss = self.ce(net_output, target) \\\n            if self.weight_ce != 0 and (self.ignore_label is None or num_fg > 0) else 0\n\n        result = self.weight_ce * ce_loss + self.weight_dice * dc_loss\n        return result\n"
  },
  {
    "path": "nnunetv2/training/loss/deep_supervision.py",
    "content": "from torch import nn\n\n\nclass DeepSupervisionWrapper(nn.Module):\n    def __init__(self, loss, weight_factors=None):\n        \"\"\"\n        Wraps a loss function so that it can be applied to multiple outputs. Forward accepts an arbitrary number of\n        inputs. Each input is expected to be a tuple/list. Each tuple/list must have the same length. The loss is then\n        applied to each entry like this:\n        l = w0 * loss(input0[0], input1[0], ...) +  w1 * loss(input0[1], input1[1], ...) + ...\n        If weights are None, all w will be 1.\n        \"\"\"\n        super(DeepSupervisionWrapper, self).__init__()\n        assert any([x != 0 for x in weight_factors]), \"At least one weight factor should be != 0.0\"\n        self.weight_factors = tuple(weight_factors)\n        self.loss = loss\n\n    def forward(self, *args):\n        assert all([isinstance(i, (tuple, list)) for i in args]), \\\n            f\"all args must be either tuple or list, got {[type(i) for i in args]}\"\n        # we could check for equal lengths here as well, but we really shouldn't overdo it with checks because\n        # this code is executed a lot of times!\n\n        if self.weight_factors is None:\n            weights = (1, ) * len(args[0])\n        else:\n            weights = self.weight_factors\n\n        return sum([weights[i] * self.loss(*inputs) for i, inputs in enumerate(zip(*args)) if weights[i] != 0.0])\n"
  },
  {
    "path": "nnunetv2/training/loss/dice.py",
    "content": "from typing import Callable\n\nimport torch\nfrom nnunetv2.utilities.ddp_allgather import AllGatherGrad\nfrom torch import nn\n\n\nclass SoftDiceLoss(nn.Module):\n    def __init__(self, apply_nonlin: Callable = None, batch_dice: bool = False, do_bg: bool = True, smooth: float = 1.,\n                 ddp: bool = True, clip_tp: float = None):\n        \"\"\"\n        \"\"\"\n        super(SoftDiceLoss, self).__init__()\n\n        self.do_bg = do_bg\n        self.batch_dice = batch_dice\n        self.apply_nonlin = apply_nonlin\n        self.smooth = smooth\n        self.clip_tp = clip_tp\n        self.ddp = ddp\n\n    def forward(self, x, y, loss_mask=None):\n        shp_x = x.shape\n\n        if self.batch_dice:\n            axes = [0] + list(range(2, len(shp_x)))\n        else:\n            axes = list(range(2, len(shp_x)))\n\n        if self.apply_nonlin is not None:\n            x = self.apply_nonlin(x)\n\n        tp, fp, fn, _ = get_tp_fp_fn_tn(x, y, axes, loss_mask, False)\n\n        if self.ddp and self.batch_dice:\n            tp = AllGatherGrad.apply(tp).sum(0, dtype=torch.float32)\n            fp = AllGatherGrad.apply(fp).sum(0, dtype=torch.float32)\n            fn = AllGatherGrad.apply(fn).sum(0, dtype=torch.float32)\n\n        if self.clip_tp is not None:\n            tp = torch.clip(tp, min=self.clip_tp , max=None)\n\n        nominator = 2 * tp\n        denominator = 2 * tp + fp + fn\n\n        dc = (nominator + self.smooth) / (torch.clip(denominator + self.smooth, 1e-8))\n\n        if not self.do_bg:\n            if self.batch_dice:\n                dc = dc[1:]\n            else:\n                dc = dc[:, 1:]\n        dc = dc.mean()\n\n        return -dc\n\n\nclass MemoryEfficientSoftDiceLoss(nn.Module):\n    def __init__(self, apply_nonlin: Callable = None, batch_dice: bool = False, do_bg: bool = True, smooth: float = 1.,\n                 ddp: bool = True):\n        \"\"\"\n        saves 1.6 GB on Dataset017 3d_lowres\n        
\"\"\"\n        super(MemoryEfficientSoftDiceLoss, self).__init__()\n\n        self.do_bg = do_bg\n        self.batch_dice = batch_dice\n        self.apply_nonlin = apply_nonlin\n        self.smooth = smooth\n        self.ddp = ddp\n\n    def forward(self, x, y, loss_mask=None):\n        if self.apply_nonlin is not None:\n            x = self.apply_nonlin(x)\n\n        # make everything shape (b, c)\n        axes = tuple(range(2, x.ndim))\n\n        with torch.no_grad():\n            if x.ndim != y.ndim:\n                y = y.view((y.shape[0], 1, *y.shape[1:]))\n\n            if x.shape == y.shape:\n                # if this is the case then gt is probably already a one hot encoding\n                y_onehot = y.to(torch.float32)\n            else:\n                y_onehot = torch.zeros(x.shape, device=x.device, dtype=torch.float32)\n                y_onehot.scatter_(1, y.long(), 1)\n\n            if not self.do_bg:\n                y_onehot = y_onehot[:, 1:]\n\n            sum_gt = y_onehot.sum(axes, dtype=torch.float32) if loss_mask is None else (y_onehot * loss_mask).sum(axes, dtype=torch.float32)\n\n        # this one MUST be outside the with torch.no_grad(): context. 
Otherwise no gradients for you\n        if not self.do_bg:\n            x = x[:, 1:]\n\n        if loss_mask is None:\n            intersect = (x * y_onehot).sum(axes, dtype=torch.float32)\n            sum_pred = x.sum(axes, dtype=torch.float32)\n        else:\n            intersect = (x * y_onehot * loss_mask).sum(axes, dtype=torch.float32)\n            sum_pred = (x * loss_mask).sum(axes, dtype=torch.float32)\n\n        if self.batch_dice:\n            if self.ddp:\n                intersect = AllGatherGrad.apply(intersect).sum(0, dtype=torch.float32)\n                sum_pred = AllGatherGrad.apply(sum_pred).sum(0, dtype=torch.float32)\n                sum_gt = AllGatherGrad.apply(sum_gt).sum(0, dtype=torch.float32)\n\n            intersect = intersect.sum(0, dtype=torch.float32)\n            sum_pred = sum_pred.sum(0, dtype=torch.float32)\n            sum_gt = sum_gt.sum(0, dtype=torch.float32)\n\n        dc = (2 * intersect + self.smooth) / (sum_gt + sum_pred + float(self.smooth)).clamp_min(1e-8)\n\n        dc = dc.mean()\n        return -dc\n\n\ndef get_tp_fp_fn_tn(net_output, gt, axes=None, mask=None, square=False):\n    \"\"\"\n    net_output must be (b, c, x, y(, z)))\n    gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z))\n    if mask is provided it must have shape (b, 1, x, y(, z)))\n    :param net_output:\n    :param gt:\n    :param axes: can be (, ) = no summation\n    :param mask: mask must be 1 for valid pixels and 0 for invalid pixels\n    :param square: if True then fp, tp and fn will be squared before summation\n    :return:\n    \"\"\"\n    if axes is None:\n        axes = tuple(range(2, net_output.ndim))\n\n    with torch.no_grad():\n        if net_output.ndim != gt.ndim:\n            gt = gt.view((gt.shape[0], 1, *gt.shape[1:]))\n\n        if net_output.shape == gt.shape:\n            # if this is the case then gt is probably already a one hot encoding\n            y_onehot = 
gt.to(torch.float32)\n        else:\n            y_onehot = torch.zeros(net_output.shape, device=net_output.device, dtype=torch.float32)\n            y_onehot.scatter_(1, gt.long(), 1)\n\n    tp = net_output * y_onehot\n    fp = net_output * (1 - y_onehot)\n    fn = (1 - net_output) * y_onehot\n    tn = (1 - net_output) * (1 - y_onehot)\n\n    if mask is not None:\n        with torch.no_grad():\n            mask_here = torch.tile(mask, (1, tp.shape[1], *[1 for _ in range(2, tp.ndim)]))\n        tp *= mask_here\n        fp *= mask_here\n        fn *= mask_here\n        tn *= mask_here\n        # benchmark whether tiling the mask would be faster (torch.tile). It probably is for large batch sizes\n        # OK it barely makes a difference but the implementation above is a tiny bit faster + uses less vram\n        # (using nnUNetv2_train 998 3d_fullres 0)\n        # tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)\n        # fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)\n        # fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)\n        # tn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tn, dim=1)), dim=1)\n\n    if square:\n        tp = tp ** 2\n        fp = fp ** 2\n        fn = fn ** 2\n        tn = tn ** 2\n\n    if len(axes) > 0:\n        tp = tp.sum(dim=axes, keepdim=False, dtype=torch.float32)\n        fp = fp.sum(dim=axes, keepdim=False, dtype=torch.float32)\n        fn = fn.sum(dim=axes, keepdim=False, dtype=torch.float32)\n        tn = tn.sum(dim=axes, keepdim=False, dtype=torch.float32)\n\n    return tp, fp, fn, tn\n\n\nif __name__ == '__main__':\n    from nnunetv2.utilities.helpers import softmax_helper_dim1\n    pred = torch.rand((2, 3, 32, 32, 32))\n    ref = torch.randint(0, 3, (2, 32, 32, 32))\n\n    dl_old = SoftDiceLoss(apply_nonlin=softmax_helper_dim1, batch_dice=True, do_bg=False, smooth=0, ddp=False)\n    dl_new = 
MemoryEfficientSoftDiceLoss(apply_nonlin=softmax_helper_dim1, batch_dice=True, do_bg=False, smooth=0, ddp=False)\n    res_old = dl_old(pred, ref)\n    res_new = dl_new(pred, ref)\n    print(res_old, res_new)\n"
  },
  {
    "path": "nnunetv2/training/loss/robust_ce_loss.py",
    "content": "import torch\nfrom torch import nn, Tensor\nimport numpy as np\n\n\nclass RobustCrossEntropyLoss(nn.CrossEntropyLoss):\n    \"\"\"\n    this is just a compatibility layer because my target tensor is float and has an extra dimension\n\n    input must be logits, not probabilities!\n    \"\"\"\n    def forward(self, input: Tensor, target: Tensor) -> Tensor:\n        if target.ndim == input.ndim:\n            assert target.shape[1] == 1\n            target = target[:, 0]\n        return super().forward(input, target.long())\n\n\nclass TopKLoss(RobustCrossEntropyLoss):\n    \"\"\"\n    input must be logits, not probabilities!\n    \"\"\"\n    def __init__(self, weight=None, ignore_index: int = -100, k: float = 10, label_smoothing: float = 0):\n        self.k = k\n        super(TopKLoss, self).__init__(weight, False, ignore_index, reduce=False, label_smoothing=label_smoothing)\n\n    def forward(self, inp, target):\n        target = target[:, 0].long()\n        res = super(TopKLoss, self).forward(inp, target)\n        num_voxels = np.prod(res.shape, dtype=np.int64)\n        res, _ = torch.topk(res.view((-1, )), int(num_voxels * self.k / 100), sorted=False)\n        return res.mean()\n"
  },
  {
    "path": "nnunetv2/training/lr_scheduler/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/training/lr_scheduler/polylr.py",
    "content": "from torch.optim.lr_scheduler import _LRScheduler\n\n\nclass PolyLRScheduler(_LRScheduler):\n    \"\"\"\n    Polynomial decay: lr = initial_lr * (1 - current_step / max_steps) ** exponent\n    \"\"\"\n    def __init__(self, optimizer, initial_lr: float, max_steps: int, exponent: float = 0.9, current_step: int = None):\n        self.optimizer = optimizer\n        self.initial_lr = initial_lr\n        self.max_steps = max_steps\n        self.exponent = exponent\n        self.ctr = 0\n        super().__init__(optimizer, current_step if current_step is not None else -1)\n\n    def step(self, current_step=None):\n        # without an explicit step, fall back to the internal counter\n        if current_step is None or current_step == -1:\n            current_step = self.ctr\n            self.ctr += 1\n\n        new_lr = self.initial_lr * (1 - current_step / self.max_steps) ** self.exponent\n        for param_group in self.optimizer.param_groups:\n            param_group['lr'] = new_lr\n\n        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]\n\n    def get_last_lr(self):\n        return self._last_lr"
  },
  {
    "path": "nnunetv2/training/lr_scheduler/warmup.py",
    "content": "import math\nimport warnings\nfrom typing import Optional, cast, List\n\nfrom torch import Tensor\nfrom torch.optim import Optimizer\nfrom torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR, _enable_get_lr_call\n\n\nclass Lin_incr_LRScheduler(_LRScheduler):\n    def __init__(self, optimizer, max_lr: float, max_steps: int, current_step: int = None):\n        self.optimizer = optimizer\n        self.max_lr = max_lr\n        self.max_steps = max_steps\n        self.ctr = 0\n        super().__init__(optimizer, current_step if current_step is not None else -1)\n\n    def step(self, current_step=None):\n        if current_step is None or current_step == -1:\n            current_step = self.ctr\n            self.ctr += 1\n\n        new_lr = self.max_lr / self.max_steps * (1 + current_step)\n        for param_group in self.optimizer.param_groups:\n            param_group[\"lr\"] = new_lr\n\n\nclass Lin_incr_offset_LRScheduler(_LRScheduler):\n    def __init__(self, optimizer, max_lr: float, max_steps: int, start_step: int, current_step: int = None):\n        self.optimizer = optimizer\n        self.max_lr = max_lr\n        self.max_steps = max_steps\n        self.start_step = start_step\n        self.ctr = 0\n        super().__init__(optimizer, current_step if current_step is not None else -1)\n\n    def step(self, current_step=None):\n        if current_step is None or current_step == -1:\n            current_step = self.ctr\n            self.ctr += 1\n\n        new_lr = self.max_lr / self.max_steps * (1 + current_step - self.start_step)\n        for param_group in self.optimizer.param_groups:\n            param_group[\"lr\"] = new_lr\n\n\nclass PolyLRScheduler_offset(_LRScheduler):\n    def __init__(\n        self,\n        optimizer,\n        initial_lr: float,\n        max_steps: int,\n        start_step: int,\n        exponent: float = 0.9,\n        current_step: int = None,\n    ):\n        self.optimizer = optimizer\n        
self.initial_lr = initial_lr\n        self.max_steps = max_steps - start_step\n        self.start_step = start_step\n        self.exponent = exponent\n        self.ctr = 0\n        super().__init__(optimizer, current_step if current_step is not None else -1)\n\n    def step(self, current_step=None):\n        if current_step is None or current_step == -1:\n            current_step = self.ctr\n            self.ctr += 1\n\n        current_step = current_step - self.start_step\n        if current_step <= 0:\n            current_step = 0\n\n        new_lr = self.initial_lr * (1 - current_step / self.max_steps) ** self.exponent\n        for param_group in self.optimizer.param_groups:\n            param_group[\"lr\"] = new_lr\n\n\nclass CosineAnnealingLR_offset(CosineAnnealingLR):\n    def __init__(\n        self, optimizer: Optimizer, T_max: int, eta_min=0, last_epoch=-1, verbose=\"deprecated\", offset: int = 0\n    ):\n        self.offset = offset\n        super().__init__(\n            optimizer,\n            T_max,\n            eta_min,\n            last_epoch,\n            verbose,\n        )\n\n    def _get_closed_form_lr(self):\n        return [\n            self.eta_min\n            + (base_lr - self.eta_min)\n            * (1 + math.cos(math.pi * (self.last_epoch - self.offset) / (self.T_max - self.offset)))\n            / 2\n            for base_lr in self.base_lrs\n        ]\n\n    def step(self, epoch: Optional[int] = None):\n\n        # Raise a warning if old pattern is detected\n        # https://github.com/pytorch/pytorch/issues/20124\n        if self._step_count == 1:\n            if not hasattr(self.optimizer.step, \"_wrapped_by_lr_sched\"):\n                warnings.warn(\n                    \"Seems like `optimizer.step()` has been overridden after learning rate scheduler \"\n                    \"initialization. Please, make sure to call `optimizer.step()` before \"\n                    \"`lr_scheduler.step()`. 
See more details at \"\n                    \"https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate\",\n                    UserWarning,\n                )\n\n            # Just check if there were two first lr_scheduler.step() calls before optimizer.step()\n            elif not getattr(self.optimizer, \"_opt_called\", False):\n                warnings.warn(\n                    \"Detected call of `lr_scheduler.step()` before `optimizer.step()`. \"\n                    \"In PyTorch 1.1.0 and later, you should call them in the opposite order: \"\n                    \"`optimizer.step()` before `lr_scheduler.step()`.  Failure to do this \"\n                    \"will result in PyTorch skipping the first value of the learning rate schedule. \"\n                    \"See more details at \"\n                    \"https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate\",\n                    UserWarning,\n                )\n        self._step_count += 1\n\n        with _enable_get_lr_call(self):\n            if epoch is None:\n                self.last_epoch += 1\n            else:\n                self.last_epoch = epoch\n            values = cast(List[float], self._get_closed_form_lr())\n\n        for i, data in enumerate(zip(self.optimizer.param_groups, values)):\n            param_group, lr = data\n            if isinstance(param_group[\"lr\"], Tensor):\n                lr_val = lr.item() if isinstance(lr, Tensor) else lr  # type: ignore[attr-defined]\n                param_group[\"lr\"].fill_(lr_val)\n            else:\n                param_group[\"lr\"] = lr\n\n        self._last_lr: List[float] = [group[\"lr\"] for group in self.optimizer.param_groups]"
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py",
    "content": "import inspect\nimport multiprocessing\nimport os\nimport shutil\nimport sys\nimport warnings\nfrom copy import deepcopy\nfrom datetime import datetime\nfrom time import time, sleep\nfrom typing import Tuple, Union, List\n\nimport numpy as np\nimport torch\nfrom batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter\nfrom batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter\nfrom batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter\nfrom batchgenerators.utilities.file_and_folder_operations import join, load_json, isfile, save_json, maybe_mkdir_p\nfrom batchgeneratorsv2.helpers.scalar_type import RandomScalar\nfrom batchgeneratorsv2.transforms.base.basic_transform import BasicTransform\nfrom batchgeneratorsv2.transforms.intensity.brightness import MultiplicativeBrightnessTransform\nfrom batchgeneratorsv2.transforms.intensity.contrast import ContrastTransform, BGContrast\nfrom batchgeneratorsv2.transforms.intensity.gamma import GammaTransform\nfrom batchgeneratorsv2.transforms.intensity.gaussian_noise import GaussianNoiseTransform\nfrom batchgeneratorsv2.transforms.nnunet.random_binary_operator import ApplyRandomBinaryOperatorTransform\nfrom batchgeneratorsv2.transforms.nnunet.remove_connected_components import \\\n    RemoveRandomConnectedComponentFromOneHotEncodingTransform\nfrom batchgeneratorsv2.transforms.nnunet.seg_to_onehot import MoveSegAsOneHotToDataTransform\nfrom batchgeneratorsv2.transforms.noise.gaussian_blur import GaussianBlurTransform\nfrom batchgeneratorsv2.transforms.spatial.low_resolution import SimulateLowResolutionTransform\nfrom batchgeneratorsv2.transforms.spatial.mirroring import MirrorTransform\nfrom batchgeneratorsv2.transforms.spatial.spatial import SpatialTransform\nfrom batchgeneratorsv2.transforms.utils.compose import ComposeTransforms\nfrom batchgeneratorsv2.transforms.utils.deep_supervision_downsampling import 
DownsampleSegForDSTransform\nfrom batchgeneratorsv2.transforms.utils.nnunet_masking import MaskImageTransform\nfrom batchgeneratorsv2.transforms.utils.pseudo2d import Convert3DTo2DTransform, Convert2DTo3DTransform\nfrom batchgeneratorsv2.transforms.utils.random import RandomTransform\nfrom batchgeneratorsv2.transforms.utils.remove_label import RemoveLabelTansform\nfrom batchgeneratorsv2.transforms.utils.seg_to_regions import ConvertSegmentationToRegionsTransform\nfrom torch import autocast, nn\nfrom torch import distributed as dist\nfrom torch._dynamo import OptimizedModule\nfrom torch.cuda import device_count\nfrom torch import GradScaler\nfrom torch.nn.parallel import DistributedDataParallel as DDP\n\nfrom nnunetv2.configuration import ANISO_THRESHOLD, default_num_processes\nfrom nnunetv2.evaluation.evaluate_predictions import compute_metrics_on_folder\nfrom nnunetv2.inference.export_prediction import export_prediction_from_logits, resample_and_save\nfrom nnunetv2.inference.predict_from_raw_data import nnUNetPredictor\nfrom nnunetv2.inference.sliding_window_prediction import compute_gaussian\nfrom nnunetv2.paths import nnUNet_preprocessed, nnUNet_results\nfrom nnunetv2.training.data_augmentation.compute_initial_patch_size import get_patch_size\nfrom nnunetv2.training.dataloading.nnunet_dataset import infer_dataset_class\nfrom nnunetv2.training.dataloading.data_loader import nnUNetDataLoader\nfrom nnunetv2.training.logging.nnunet_logger import MetaLogger\nfrom nnunetv2.training.loss.compound_losses import DC_and_CE_loss, DC_and_BCE_loss\nfrom nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper\nfrom nnunetv2.training.loss.dice import get_tp_fp_fn_tn, MemoryEfficientSoftDiceLoss\nfrom nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler\nfrom nnunetv2.utilities.collate_outputs import collate_outputs\nfrom nnunetv2.utilities.crossval_split import generate_crossval_split\nfrom nnunetv2.utilities.default_n_proc_DA import 
get_allowed_n_proc_DA\nfrom nnunetv2.utilities.file_path_utilities import check_workers_alive_and_busy\nfrom nnunetv2.utilities.get_network_from_plans import get_network_from_plans\nfrom nnunetv2.utilities.helpers import empty_cache, dummy_context\nfrom nnunetv2.utilities.label_handling.label_handling import convert_labelmap_to_one_hot, determine_num_input_channels\nfrom nnunetv2.utilities.plans_handling.plans_handler import PlansManager\n\n\nclass nnUNetTrainer(object):\n    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,\n                 device: torch.device = torch.device('cuda')):\n        # From https://grugbrain.dev/. Worth a read ya big brains ;-)\n\n        # apex predator of grug is complexity\n        # complexity bad\n        # say again:\n        # complexity very bad\n        # you say now:\n        # complexity very, very bad\n        # given choice between complexity or one on one against t-rex, grug take t-rex: at least grug see t-rex\n        # complexity is spirit demon that enter codebase through well-meaning but ultimately very clubbable non grug-brain developers and project managers who not fear complexity spirit demon or even know about sometime\n        # one day code base understandable and grug can get work done, everything good!\n        # next day impossible: complexity demon spirit has entered code and very dangerous situation!\n\n        # OK OK I am guilty. But I tried.\n        # https://www.osnews.com/images/comics/wtfm.jpg\n        # https://i.pinimg.com/originals/26/b2/50/26b250a738ea4abc7a5af4d42ad93af0.jpg\n\n        self.is_ddp = dist.is_available() and dist.is_initialized()\n        self.local_rank = 0 if not self.is_ddp else dist.get_rank()\n\n        self.device = device\n\n        # print what device we are using\n        if self.is_ddp:  # implicitly it's clear that we use cuda in this case\n            print(f\"I am local rank {self.local_rank}. {device_count()} GPUs are available. 
The world size is \"\n                  f\"{dist.get_world_size()}.\"\n                  f\"Setting device to {self.device}\")\n            self.device = torch.device(type='cuda', index=self.local_rank)\n        else:\n            if self.device.type == 'cuda':\n                # we might want to let the user pick this but for now please pick the correct GPU with CUDA_VISIBLE_DEVICES=X\n                self.device = torch.device(type='cuda', index=0)\n            print(f\"Using device: {self.device}\")\n\n        # loading and saving this class for continuing from checkpoint should not happen based on pickling. This\n        # would also pickle the network etc. Bad, bad. Instead we just reinstantiate and then load the checkpoint we\n        # need. So let's save the init args\n        self.my_init_kwargs = {}\n        for k in inspect.signature(self.__init__).parameters.keys():\n            self.my_init_kwargs[k] = locals()[k]\n\n        ###  Saving all the init args into class variables for later access\n        continue_training = plans.pop(\"continue_training\")\n        logger_config = {\"plans\": plans, \"configuration\": configuration, \"fold\": fold, \"dataset\": dataset_json}\n        self.plans_manager = PlansManager(plans)\n        self.configuration_manager = self.plans_manager.get_configuration(configuration)\n        self.configuration_name = configuration\n        self.dataset_json = dataset_json\n        self.fold = fold\n\n        ### Setting all the folder names. 
We need to make sure things don't crash in case we are just running\n        # inference and some of the folders may not be defined!\n        self.preprocessed_dataset_folder_base = join(nnUNet_preprocessed, self.plans_manager.dataset_name) \\\n            if nnUNet_preprocessed is not None else None\n        self.output_folder_base = join(nnUNet_results, self.plans_manager.dataset_name,\n                                       self.__class__.__name__ + '__' + self.plans_manager.plans_name + \"__\" + configuration) \\\n            if nnUNet_results is not None else None\n        self.output_folder = join(self.output_folder_base, f'fold_{fold}')\n\n        self.preprocessed_dataset_folder = join(self.preprocessed_dataset_folder_base,\n                                                self.configuration_manager.data_identifier)\n        self.dataset_class = None  # -> initialize\n        # unlike the previous nnunet folder_with_segs_from_previous_stage is now part of the plans. For now it has to\n        # be a different configuration in the same plans\n        # IMPORTANT! the mapping must be bijective, so lowres must point to fullres and vice versa (using\n        # \"previous_stage\" and \"next_stage\"). 
Otherwise it won't work!\n        self.is_cascaded = self.configuration_manager.previous_stage_name is not None\n        self.folder_with_segs_from_previous_stage = \\\n            join(nnUNet_results, self.plans_manager.dataset_name,\n                 self.__class__.__name__ + '__' + self.plans_manager.plans_name + \"__\" +\n                 self.configuration_manager.previous_stage_name, 'predicted_next_stage', self.configuration_name) \\\n                if self.is_cascaded else None\n\n        ### Some hyperparameters for you to fiddle with\n        self.initial_lr = 1e-2\n        self.weight_decay = 3e-5\n        self.oversample_foreground_percent = 0.33\n        self.probabilistic_oversampling = False\n        self.num_iterations_per_epoch = 250\n        self.num_val_iterations_per_epoch = 50\n        self.num_epochs = 1000\n        self.current_epoch = 0\n        self.enable_deep_supervision = True\n\n        ### Dealing with labels/regions\n        self.label_manager = self.plans_manager.get_label_manager(dataset_json)\n        # labels can either be a list of int (regular training) or a list of tuples of int (region-based training)\n        # needed for predictions. We do sigmoid in case of (overlapping) regions\n\n        self.num_input_channels = None  # -> self.initialize()\n        self.network = None  # -> self.build_network_architecture()\n        self.optimizer = self.lr_scheduler = None  # -> self.initialize\n        self.grad_scaler = GradScaler(\"cuda\") if self.device.type == 'cuda' else None\n        self.loss = None  # -> self.initialize\n\n        ### Simple logging. Don't take that away from me!\n        # initialize log file. This is just our log for the print statements etc. 
Not to be confused with lightning\n        # logging\n        timestamp = datetime.now()\n        maybe_mkdir_p(self.output_folder)\n        self.log_file = join(self.output_folder, \"training_log_%d_%d_%d_%02.0d_%02.0d_%02.0d.txt\" %\n                             (timestamp.year, timestamp.month, timestamp.day, timestamp.hour, timestamp.minute,\n                              timestamp.second))\n        self.logger = MetaLogger(self.output_folder, continue_training)\n        self.logger.update_config(logger_config)\n\n        ### placeholders\n        self.dataloader_train = self.dataloader_val = None  # see on_train_start\n\n        ### initializing stuff for remembering things and such\n        self._best_ema = None\n\n        ### inference things\n        self.inference_allowed_mirroring_axes = None  # this variable is set in\n        # self.configure_rotation_dummyDA_mirroring_and_inital_patch_size and will be saved in checkpoints\n\n        ### checkpoint saving stuff\n        self.save_every = 50\n        self.disable_checkpointing = False\n\n        self.was_initialized = False\n\n        self.print_to_log_file(\"\\n#######################################################################\\n\"\n                               \"Please cite the following paper when using nnU-Net:\\n\"\n                               \"Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). \"\n                               \"nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. 
\"\n                               \"Nature methods, 18(2), 203-211.\\n\"\n                               \"#######################################################################\\n\",\n                               also_print_to_console=True, add_timestamp=False)\n\n    def initialize(self):\n        if not self.was_initialized:\n            ## DDP batch size and oversampling can differ between workers and needs adaptation\n            # we need to change the batch size in DDP because we don't use any of those distributed samplers\n            self._set_batch_size_and_oversample()\n\n            self.num_input_channels = determine_num_input_channels(self.plans_manager, self.configuration_manager,\n                                                                   self.dataset_json)\n\n            self.network = self.build_network_architecture(\n                self.configuration_manager.network_arch_class_name,\n                self.configuration_manager.network_arch_init_kwargs,\n                self.configuration_manager.network_arch_init_kwargs_req_import,\n                self.num_input_channels,\n                self.label_manager.num_segmentation_heads,\n                self.enable_deep_supervision\n            ).to(self.device)\n            # compile network for free speedup\n            if self._do_i_compile():\n                self.print_to_log_file('Using torch.compile...')\n                self.network = torch.compile(self.network)\n\n            self.optimizer, self.lr_scheduler = self.configure_optimizers()\n            # if ddp, wrap in DDP wrapper\n            if self.is_ddp:\n                self.network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.network)\n                self.network = DDP(self.network, device_ids=[self.local_rank])\n\n            self.loss = self._build_loss()\n\n            self.dataset_class = infer_dataset_class(self.preprocessed_dataset_folder)\n\n            # torch 2.2.2 crashes upon compiling CE loss\n          
  # if self._do_i_compile():\n            #     self.loss = torch.compile(self.loss)\n            self.was_initialized = True\n\n            logger_config_hparas = {\n                \"initial_lr\": self.initial_lr,\n                \"weight_decay\": self.weight_decay,\n                \"oversample_foreground_percent\": self.oversample_foreground_percent,\n                \"probabilistic_oversampling\": self.probabilistic_oversampling,\n                \"num_iterations_per_epoch\": self.num_iterations_per_epoch,\n                \"num_val_iterations_per_epoch\": self.num_val_iterations_per_epoch,\n                \"num_epochs\": self.num_epochs,\n                \"enable_deep_supervision\": self.enable_deep_supervision,\n                \"batch_size\": self.configuration_manager.batch_size\n                }\n            self.logger.update_config({\"hparas\": logger_config_hparas})\n        else:\n            raise RuntimeError(\"You have called self.initialize even though the trainer was already initialized. \"\n                               \"That should not happen.\")\n\n    def _do_i_compile(self):\n        # new default: compile is enabled!\n\n        # compile does not work on mps\n        if self.device == torch.device('mps'):\n            if 'nnUNet_compile' in os.environ.keys() and os.environ['nnUNet_compile'].lower() in ('true', '1', 't'):\n                self.print_to_log_file(\"INFO: torch.compile disabled because of unsupported mps device\")\n            return False\n\n        # CPU compile crashes for 2D models. Not sure if we even want to support CPU compile!? 
Better disable\n        if self.device == torch.device('cpu'):\n            if 'nnUNet_compile' in os.environ.keys() and os.environ['nnUNet_compile'].lower() in ('true', '1', 't'):\n                self.print_to_log_file(\"INFO: torch.compile disabled because device is CPU\")\n            return False\n\n        # default torch.compile doesn't work on windows because there are apparently no triton wheels for it\n        # https://discuss.pytorch.org/t/windows-support-timeline-for-torch-compile/182268/2\n        if os.name == 'nt':\n            if 'nnUNet_compile' in os.environ.keys() and os.environ['nnUNet_compile'].lower() in ('true', '1', 't'):\n                self.print_to_log_file(\"INFO: torch.compile disabled because Windows is not natively supported. If \"\n                                       \"you know what you are doing, check https://discuss.pytorch.org/t/windows-support-timeline-for-torch-compile/182268/2\")\n            return False\n\n        if 'nnUNet_compile' not in os.environ.keys():\n            return True\n        else:\n            return os.environ['nnUNet_compile'].lower() in ('true', '1', 't')\n\n    def _save_debug_information(self):\n        # saving some debug information\n        if self.local_rank == 0:\n            dct = {}\n            for k in self.__dir__():\n                if not k.startswith(\"__\"):\n                    if not callable(getattr(self, k)) or k in ['loss', ]:\n                        dct[k] = str(getattr(self, k))\n                    elif k in ['network', ]:\n                        dct[k] = str(getattr(self, k).__class__.__name__)\n                    else:\n                        # print(k)\n                        pass\n                if k in ['dataloader_train', 'dataloader_val']:\n                    dl = getattr(self, k)\n                    if hasattr(dl, 'generator'):\n                        dct[k + '.generator'] = str(dl.generator)\n                        if hasattr(dl.generator, 'transforms'):\n  
                          try:\n                                dct[k + '.generator.transforms'] = str(dl.generator.transforms)\n                            except Exception as e:\n                                dct[k + '.generator.transforms'] = f\"Could not stringify generator.transforms: {type(e).__name__}: {e}\"\n                    if hasattr(dl, 'num_processes'):\n                        dct[k + '.num_processes'] = str(dl.num_processes)\n                    if hasattr(dl, 'transform'):\n                        dct[k + '.transform'] = str(dl.transform)\n            import subprocess\n            hostname = subprocess.getoutput(['hostname'])\n            dct['hostname'] = hostname\n            torch_version = torch.__version__\n            if self.device.type == 'cuda':\n                gpu_name = torch.cuda.get_device_name()\n                dct['gpu_name'] = gpu_name\n                cudnn_version = torch.backends.cudnn.version()\n            else:\n                cudnn_version = 'None'\n            dct['device'] = str(self.device)\n            dct['torch_version'] = torch_version\n            dct['cudnn_version'] = cudnn_version\n            save_json(dct, join(self.output_folder, \"debug.json\"))\n\n    @staticmethod\n    def build_network_architecture(architecture_class_name: str,\n                                   arch_init_kwargs: dict,\n                                   arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],\n                                   num_input_channels: int,\n                                   num_output_channels: int,\n                                   enable_deep_supervision: bool = True) -> nn.Module:\n        \"\"\"\n        This is where you build the architecture according to the plans. There is no obligation to use\n        get_network_from_plans, this is just a utility we use for the nnU-Net default architectures. You can do what\n        you want. 
Even ignore the plans and just return something static (as long as it can process the requested\n        patch size)\n        but don't bug us with your bugs arising from fiddling with this :-P\n        This is the function that is called in inference as well! This is needed so that all network architecture\n        variants can be loaded at inference time (inference will use the same nnUNetTrainer that was used for\n        training, so if you change the network architecture during training by deriving a new trainer class then\n        inference will know about it).\n\n        If you need to know how many segmentation outputs your custom architecture needs to have, use the following snippet:\n        > label_manager = plans_manager.get_label_manager(dataset_json)\n        > label_manager.num_segmentation_heads\n        (why so complicated? -> We can have either classical training (classes) or regions. If we have regions,\n        the number of outputs is != the number of classes. Also there is the ignore label for which no output\n        should be generated. 
label_manager takes care of all that for you.)\n\n        \"\"\"\n        return get_network_from_plans(\n            architecture_class_name,\n            arch_init_kwargs,\n            arch_init_kwargs_req_import,\n            num_input_channels,\n            num_output_channels,\n            allow_init=True,\n            deep_supervision=enable_deep_supervision)\n\n    def _get_deep_supervision_scales(self):\n        if self.enable_deep_supervision:\n            deep_supervision_scales = list(list(i) for i in 1 / np.cumprod(np.vstack(\n                self.configuration_manager.pool_op_kernel_sizes), axis=0))[:-1]\n        else:\n            deep_supervision_scales = None  # for train and val_transforms\n        return deep_supervision_scales\n\n    def _set_batch_size_and_oversample(self):\n        if not self.is_ddp:\n            # set batch size to what the plan says, leave oversample untouched\n            self.batch_size = self.configuration_manager.batch_size\n        else:\n            # batch size is distributed over DDP workers and we need to change oversample_percent for each worker\n\n            world_size = dist.get_world_size()\n            my_rank = dist.get_rank()\n\n            global_batch_size = self.configuration_manager.batch_size\n            assert global_batch_size >= world_size, 'Cannot run DDP if the batch size is smaller than the number of ' \\\n                                                    'GPUs... 
Duh.'\n\n            batch_size_per_GPU = [global_batch_size // world_size] * world_size\n            batch_size_per_GPU = [batch_size_per_GPU[i] + 1\n                                  if (batch_size_per_GPU[i] * world_size + i) < global_batch_size\n                                  else batch_size_per_GPU[i]\n                                  for i in range(len(batch_size_per_GPU))]\n            assert sum(batch_size_per_GPU) == global_batch_size\n\n            sample_id_low = 0 if my_rank == 0 else np.sum(batch_size_per_GPU[:my_rank])\n            sample_id_high = np.sum(batch_size_per_GPU[:my_rank + 1])\n\n            # This is how oversampling is determined in DataLoader\n            # round(self.batch_size * (1 - self.oversample_foreground_percent))\n            # We need to use the same scheme here because an oversample of 0.33 with a batch size of 2 will be rounded\n            # to an oversample of 0.5 (1 sample random, one oversampled). This may get lost if we just numerically\n            # compute oversample\n            oversample = [True if not i < round(global_batch_size * (1 - self.oversample_foreground_percent)) else False\n                          for i in range(global_batch_size)]\n\n            if sample_id_high / global_batch_size < (1 - self.oversample_foreground_percent):\n                oversample_percent = 0.0\n            elif sample_id_low / global_batch_size > (1 - self.oversample_foreground_percent):\n                oversample_percent = 1.0\n            else:\n                oversample_percent = sum(oversample[sample_id_low:sample_id_high]) / batch_size_per_GPU[my_rank]\n\n            print(\"worker\", my_rank, \"oversample\", oversample_percent)\n            print(\"worker\", my_rank, \"batch_size\", batch_size_per_GPU[my_rank])\n\n            self.batch_size = batch_size_per_GPU[my_rank]\n            self.oversample_foreground_percent = oversample_percent\n\n    def _build_loss(self):\n        if self.label_manager.has_regions:\n   
         loss = DC_and_BCE_loss({},\n                                   {'batch_dice': self.configuration_manager.batch_dice,\n                                    'do_bg': True, 'smooth': 1e-5, 'ddp': self.is_ddp},\n                                   use_ignore_label=self.label_manager.ignore_label is not None,\n                                   dice_class=MemoryEfficientSoftDiceLoss)\n        else:\n            loss = DC_and_CE_loss({'batch_dice': self.configuration_manager.batch_dice,\n                                   'smooth': 1e-5, 'do_bg': False, 'ddp': self.is_ddp}, {}, weight_ce=1, weight_dice=1,\n                                  ignore_label=self.label_manager.ignore_label, dice_class=MemoryEfficientSoftDiceLoss)\n\n        if self._do_i_compile():\n            loss.dc = torch.compile(loss.dc)\n\n        # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases\n        # this gives higher resolution outputs more weight in the loss\n\n        if self.enable_deep_supervision:\n            deep_supervision_scales = self._get_deep_supervision_scales()\n            weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))])\n            if self.is_ddp and not self._do_i_compile():\n                # very strange and stupid interaction. DDP crashes and complains about unused parameters due to\n                # weights[-1] = 0. Interestingly this crash doesn't happen with torch.compile enabled. Strange stuff.\n                # Anywho, the simple fix is to set a very low weight to this.\n                weights[-1] = 1e-6\n            else:\n                weights[-1] = 0\n\n            # we don't use the lowest 2 outputs. 
Normalize weights so that they sum to 1\n            weights = weights / weights.sum()\n            # now wrap the loss\n            loss = DeepSupervisionWrapper(loss, weights)\n\n        return loss\n\n    def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):\n        \"\"\"\n        This function is stupid and certainly one of the weakest spots of this implementation. Not entirely sure how we can fix it.\n        \"\"\"\n        patch_size = self.configuration_manager.patch_size\n        dim = len(patch_size)\n        # todo rotation should be defined dynamically based on patch size (more isotropic patch sizes = more rotation)\n        if dim == 2:\n            do_dummy_2d_data_aug = False\n            # todo revisit this parametrization\n            if max(patch_size) / min(patch_size) > 1.5:\n                rotation_for_DA = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi)\n            else:\n                rotation_for_DA = (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi)\n            mirror_axes = (0, 1)\n        elif dim == 3:\n            # todo this is not ideal. We could also have patch_size (64, 16, 128) in which case a full 180deg 2d rot would be bad\n            # order of the axes is determined by spacing, not image size\n            do_dummy_2d_data_aug = (max(patch_size) / patch_size[0]) > ANISO_THRESHOLD\n            if do_dummy_2d_data_aug:\n                # why do we rotate 180 deg here all the time? We should also restrict it\n                rotation_for_DA = (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi)\n            else:\n                rotation_for_DA = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)\n            mirror_axes = (0, 1, 2)\n        else:\n            raise RuntimeError()\n\n        # todo this function is stupid. 
It doesn't even use the correct scale range (we keep things as they were in the\n        #  old nnunet for now)\n        initial_patch_size = get_patch_size(patch_size[-dim:],\n                                            rotation_for_DA,\n                                            rotation_for_DA,\n                                            rotation_for_DA,\n                                            (0.85, 1.25))\n        if do_dummy_2d_data_aug:\n            initial_patch_size[0] = patch_size[0]\n\n        self.print_to_log_file(f'do_dummy_2d_data_aug: {do_dummy_2d_data_aug}')\n        self.inference_allowed_mirroring_axes = mirror_axes\n\n        return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes\n\n    def print_to_log_file(self, *args, also_print_to_console=True, add_timestamp=True):\n        if self.local_rank == 0:\n            timestamp = time()\n            dt_object = datetime.fromtimestamp(timestamp)\n\n            if add_timestamp:\n                args = (f\"{dt_object}:\", *args)\n\n            successful = False\n            max_attempts = 5\n            ctr = 0\n            while not successful and ctr < max_attempts:\n                try:\n                    with open(self.log_file, 'a+') as f:\n                        for a in args:\n                            f.write(str(a))\n                            f.write(\" \")\n                        f.write(\"\\n\")\n                    successful = True\n                except IOError:\n                    print(f\"{datetime.fromtimestamp(timestamp)}: failed to log: \", sys.exc_info())\n                    sleep(0.5)\n                    ctr += 1\n            if also_print_to_console:\n                print(*args)\n        elif also_print_to_console:\n            print(*args)\n\n    def print_plans(self):\n        if self.local_rank == 0:\n            dct = deepcopy(self.plans_manager.plans)\n            del dct['configurations']\n            
self.print_to_log_file(f\"\\nThis is the configuration used by this \"\n                                   f\"training:\\nConfiguration name: {self.configuration_name}\\n\",\n                                   self.configuration_manager, '\\n', add_timestamp=False)\n            self.print_to_log_file('These are the global plan.json settings:\\n', dct, '\\n', add_timestamp=False)\n\n    def configure_optimizers(self):\n        optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,\n                                    momentum=0.99, nesterov=True)\n        lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs)\n        return optimizer, lr_scheduler\n\n    def plot_network_architecture(self):\n        if self._do_i_compile():\n            self.print_to_log_file(\"Unable to plot network architecture: nnUNet_compile is enabled!\")\n            return\n\n        if self.local_rank == 0:\n            try:\n                # raise NotImplementedError('hiddenlayer no longer works and we do not have a viable alternative :-(')\n                # pip install git+https://github.com/saugatkandel/hiddenlayer.git\n\n                # from torchviz import make_dot\n                # # not viable.\n                # make_dot(tuple(self.network(torch.rand((1, self.num_input_channels,\n                #                                         *self.configuration_manager.patch_size),\n                #                                        device=self.device)))).render(\n                #     join(self.output_folder, \"network_architecture.pdf\"), format='pdf')\n                # self.optimizer.zero_grad()\n\n                # broken.\n\n                import hiddenlayer as hl\n                g = hl.build_graph(self.network,\n                                   torch.rand((1, self.num_input_channels,\n                                               *self.configuration_manager.patch_size),\n                          
                    device=self.device),\n                                   transforms=None)\n                g.save(join(self.output_folder, \"network_architecture.pdf\"))\n                del g\n            except Exception as e:\n                self.print_to_log_file(\"Unable to plot network architecture:\")\n                self.print_to_log_file(e)\n\n                # self.print_to_log_file(\"\\nprinting the network instead:\\n\")\n                # self.print_to_log_file(self.network)\n                # self.print_to_log_file(\"\\n\")\n            finally:\n                empty_cache(self.device)\n\n    def do_split(self):\n        \"\"\"\n        The default split is a 5 fold CV on all available training cases. nnU-Net will create a split (it is seeded,\n        so always the same) and save it as splits_final.json file in the preprocessed data directory.\n        Sometimes you may want to create your own split for various reasons. For this you will need to create your own\n        splits_final.json file. If this file is present, nnU-Net is going to use it and whatever splits are defined in\n        it. You can create as many splits in this file as you want. 
Note that if you define only 4 splits (fold 0-3)\n        and then set fold=4 when training (that would be the fifth split), nnU-Net will print a warning and proceed to\n        use a random 80:20 data split.\n        :return:\n        \"\"\"\n        if self.dataset_class is None:\n            self.dataset_class = infer_dataset_class(self.preprocessed_dataset_folder)\n\n        if self.fold == \"all\":\n            # if fold==all then we use all images for training and validation\n            case_identifiers = self.dataset_class.get_identifiers(self.preprocessed_dataset_folder)\n            tr_keys = case_identifiers\n            val_keys = tr_keys\n        else:\n            splits_file = join(self.preprocessed_dataset_folder_base, \"splits_final.json\")\n            dataset = self.dataset_class(self.preprocessed_dataset_folder,\n                                         identifiers=None,\n                                         folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage)\n            # if the split file does not exist we need to create it\n            if not isfile(splits_file):\n                self.print_to_log_file(\"Creating new 5-fold cross-validation split...\")\n                all_keys_sorted = list(np.sort(list(dataset.identifiers)))\n                splits = generate_crossval_split(all_keys_sorted, seed=12345, n_splits=5)\n                save_json(splits, splits_file)\n\n            else:\n                self.print_to_log_file(\"Using splits from existing split file:\", splits_file)\n                splits = load_json(splits_file)\n                self.print_to_log_file(f\"The split file contains {len(splits)} splits.\")\n\n            self.print_to_log_file(\"Desired fold for training: %d\" % self.fold)\n            if self.fold < len(splits):\n                tr_keys = splits[self.fold]['train']\n                val_keys = splits[self.fold]['val']\n                self.print_to_log_file(\"This split has %d 
training and %d validation cases.\"\n                                       % (len(tr_keys), len(val_keys)))\n            else:\n                self.print_to_log_file(\"INFO: You requested fold %d for training but splits \"\n                                       \"contain only %d folds. I am now creating a \"\n                                       \"random (but seeded) 80:20 split!\" % (self.fold, len(splits)))\n                # if we request a fold that is not in the split file, create a random 80:20 split\n                rnd = np.random.RandomState(seed=12345 + self.fold)\n                keys = np.sort(list(dataset.identifiers))\n                idx_tr = rnd.choice(len(keys), int(len(keys) * 0.8), replace=False)\n                idx_val = [i for i in range(len(keys)) if i not in idx_tr]\n                tr_keys = [keys[i] for i in idx_tr]\n                val_keys = [keys[i] for i in idx_val]\n                self.print_to_log_file(\"This random 80:20 split has %d training and %d validation cases.\"\n                                       % (len(tr_keys), len(val_keys)))\n            if any([i in val_keys for i in tr_keys]):\n                self.print_to_log_file('WARNING: Some validation cases are also in the training set. Please check the '\n                                       'splits.json or ignore if this is intentional.')\n        return tr_keys, val_keys\n\n    def get_tr_and_val_datasets(self):\n        # create dataset split\n        tr_keys, val_keys = self.do_split()\n\n        # load the datasets for training and validation. 
Note that we always draw random samples so we really don't\n        # care about distributing training cases across GPUs.\n        dataset_tr = self.dataset_class(self.preprocessed_dataset_folder, tr_keys,\n                                        folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage)\n        dataset_val = self.dataset_class(self.preprocessed_dataset_folder, val_keys,\n                                         folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage)\n        return dataset_tr, dataset_val\n\n    def get_dataloaders(self):\n        if self.dataset_class is None:\n            self.dataset_class = infer_dataset_class(self.preprocessed_dataset_folder)\n\n        # we use the patch size to determine whether we need 2D or 3D dataloaders. We also use it to determine whether\n        # we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be\n        patch_size = self.configuration_manager.patch_size\n\n        # needed for deep supervision: how much do we need to downscale the segmentation targets for the different\n        # outputs?\n        deep_supervision_scales = self._get_deep_supervision_scales()\n\n        (\n            rotation_for_DA,\n            do_dummy_2d_data_aug,\n            initial_patch_size,\n            mirror_axes,\n        ) = self.configure_rotation_dummyDA_mirroring_and_inital_patch_size()\n\n        # training pipeline\n        tr_transforms = self.get_training_transforms(\n            patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug,\n            use_mask_for_norm=self.configuration_manager.use_mask_for_norm,\n            is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.foreground_labels,\n            regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None,\n            ignore_label=self.label_manager.ignore_label)\n\n        # 
validation pipeline\n        val_transforms = self.get_validation_transforms(deep_supervision_scales,\n                                                        is_cascaded=self.is_cascaded,\n                                                        foreground_labels=self.label_manager.foreground_labels,\n                                                        regions=self.label_manager.foreground_regions if\n                                                        self.label_manager.has_regions else None,\n                                                        ignore_label=self.label_manager.ignore_label)\n\n        dataset_tr, dataset_val = self.get_tr_and_val_datasets()\n        dl_tr = nnUNetDataLoader(dataset_tr, self.batch_size,\n                                 initial_patch_size,\n                                 self.configuration_manager.patch_size,\n                                 self.label_manager,\n                                 oversample_foreground_percent=self.oversample_foreground_percent,\n                                 sampling_probabilities=None, pad_sides=None, transforms=tr_transforms,\n                                 probabilistic_oversampling=self.probabilistic_oversampling)\n        dl_val = nnUNetDataLoader(dataset_val, self.batch_size,\n                                  self.configuration_manager.patch_size,\n                                  self.configuration_manager.patch_size,\n                                  self.label_manager,\n                                  oversample_foreground_percent=self.oversample_foreground_percent,\n                                  sampling_probabilities=None, pad_sides=None, transforms=val_transforms,\n                                  probabilistic_oversampling=self.probabilistic_oversampling)\n\n        allowed_num_processes = get_allowed_n_proc_DA()\n        if allowed_num_processes == 0:\n            mt_gen_train = SingleThreadedAugmenter(dl_tr, None)\n            mt_gen_val = 
SingleThreadedAugmenter(dl_val, None)\n        else:\n            mt_gen_train = NonDetMultiThreadedAugmenter(data_loader=dl_tr, transform=None,\n                                                        num_processes=allowed_num_processes,\n                                                        num_cached=max(6, allowed_num_processes // 2), seeds=None,\n                                                        pin_memory=self.device.type == 'cuda', wait_time=0.002)\n            mt_gen_val = NonDetMultiThreadedAugmenter(data_loader=dl_val,\n                                                      transform=None, num_processes=max(1, allowed_num_processes // 2),\n                                                      num_cached=max(3, allowed_num_processes // 4), seeds=None,\n                                                      pin_memory=self.device.type == 'cuda',\n                                                      wait_time=0.002)\n        # # let's get this party started\n        _ = next(mt_gen_train)\n        _ = next(mt_gen_val)\n        return mt_gen_train, mt_gen_val\n\n    @staticmethod\n    def get_training_transforms(\n            patch_size: Union[np.ndarray, Tuple[int]],\n            rotation_for_DA: RandomScalar,\n            deep_supervision_scales: Union[List, Tuple, None],\n            mirror_axes: Tuple[int, ...],\n            do_dummy_2d_data_aug: bool,\n            use_mask_for_norm: List[bool] = None,\n            is_cascaded: bool = False,\n            foreground_labels: Union[Tuple[int, ...], List[int]] = None,\n            regions: List[Union[List[int], Tuple[int, ...], int]] = None,\n            ignore_label: int = None,\n    ) -> BasicTransform:\n        transforms = []\n        if do_dummy_2d_data_aug:\n            ignore_axes = (0,)\n            transforms.append(Convert3DTo2DTransform())\n            patch_size_spatial = patch_size[1:]\n        else:\n            patch_size_spatial = patch_size\n            ignore_axes = None\n        
transforms.append(\n            SpatialTransform(\n                patch_size_spatial, patch_center_dist_from_border=0, random_crop=False, p_elastic_deform=0,\n                p_rotation=0.2,\n                rotation=rotation_for_DA, p_scaling=0.2, scaling=(0.7, 1.4), p_synchronize_scaling_across_axes=1,\n                bg_style_seg_sampling=False  # , mode_seg='nearest'\n            )\n        )\n\n        if do_dummy_2d_data_aug:\n            transforms.append(Convert2DTo3DTransform())\n\n        transforms.append(RandomTransform(\n            GaussianNoiseTransform(\n                noise_variance=(0, 0.1),\n                p_per_channel=1,\n                synchronize_channels=True\n            ), apply_probability=0.1\n        ))\n        transforms.append(RandomTransform(\n            GaussianBlurTransform(\n                blur_sigma=(0.5, 1.),\n                synchronize_channels=False,\n                synchronize_axes=False,\n                p_per_channel=0.5, benchmark=True\n            ), apply_probability=0.2\n        ))\n        transforms.append(RandomTransform(\n            MultiplicativeBrightnessTransform(\n                multiplier_range=BGContrast((0.75, 1.25)),\n                synchronize_channels=False,\n                p_per_channel=1\n            ), apply_probability=0.15\n        ))\n        transforms.append(RandomTransform(\n            ContrastTransform(\n                contrast_range=BGContrast((0.75, 1.25)),\n                preserve_range=True,\n                synchronize_channels=False,\n                p_per_channel=1\n            ), apply_probability=0.15\n        ))\n        transforms.append(RandomTransform(\n            SimulateLowResolutionTransform(\n                scale=(0.5, 1),\n                synchronize_channels=False,\n                synchronize_axes=True,\n                ignore_axes=ignore_axes,\n                allowed_channels=None,\n                p_per_channel=0.5\n            ), apply_probability=0.25\n 
       ))\n        transforms.append(RandomTransform(\n            GammaTransform(\n                gamma=BGContrast((0.7, 1.5)),\n                p_invert_image=1,\n                synchronize_channels=False,\n                p_per_channel=1,\n                p_retain_stats=1\n            ), apply_probability=0.1\n        ))\n        transforms.append(RandomTransform(\n            GammaTransform(\n                gamma=BGContrast((0.7, 1.5)),\n                p_invert_image=0,\n                synchronize_channels=False,\n                p_per_channel=1,\n                p_retain_stats=1\n            ), apply_probability=0.3\n        ))\n        if mirror_axes is not None and len(mirror_axes) > 0:\n            transforms.append(\n                MirrorTransform(\n                    allowed_axes=mirror_axes\n                )\n            )\n\n        if use_mask_for_norm is not None and any(use_mask_for_norm):\n            transforms.append(MaskImageTransform(\n                apply_to_channels=[i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]],\n                channel_idx_in_seg=0,\n                set_outside_to=0,\n            ))\n\n        transforms.append(\n            RemoveLabelTansform(-1, 0)\n        )\n        if is_cascaded:\n            assert foreground_labels is not None, 'We need foreground_labels for cascade augmentations'\n            transforms.append(\n                MoveSegAsOneHotToDataTransform(\n                    source_channel_idx=1,\n                    all_labels=foreground_labels,\n                    remove_channel_from_source=True\n                )\n            )\n            transforms.append(\n                RandomTransform(\n                    ApplyRandomBinaryOperatorTransform(\n                        channel_idx=list(range(-len(foreground_labels), 0)),\n                        strel_size=(1, 8),\n                        p_per_label=1\n                    ), apply_probability=0.4\n                )\n       
     )\n            transforms.append(\n                RandomTransform(\n                    RemoveRandomConnectedComponentFromOneHotEncodingTransform(\n                        channel_idx=list(range(-len(foreground_labels), 0)),\n                        fill_with_other_class_p=0,\n                        dont_do_if_covers_more_than_x_percent=0.15,\n                        p_per_label=1\n                    ), apply_probability=0.2\n                )\n            )\n\n        if regions is not None:\n            # the ignore label must also be converted\n            transforms.append(\n                ConvertSegmentationToRegionsTransform(\n                    regions=list(regions) + [ignore_label] if ignore_label is not None else regions,\n                    channel_in_seg=0\n                )\n            )\n\n        if deep_supervision_scales is not None:\n            transforms.append(DownsampleSegForDSTransform(ds_scales=deep_supervision_scales))\n\n        return ComposeTransforms(transforms)\n\n    @staticmethod\n    def get_validation_transforms(\n            deep_supervision_scales: Union[List, Tuple, None],\n            is_cascaded: bool = False,\n            foreground_labels: Union[Tuple[int, ...], List[int]] = None,\n            regions: List[Union[List[int], Tuple[int, ...], int]] = None,\n            ignore_label: int = None,\n    ) -> BasicTransform:\n        transforms = []\n        transforms.append(\n            RemoveLabelTansform(-1, 0)\n        )\n\n        if is_cascaded:\n            transforms.append(\n                MoveSegAsOneHotToDataTransform(\n                    source_channel_idx=1,\n                    all_labels=foreground_labels,\n                    remove_channel_from_source=True\n                )\n            )\n\n        if regions is not None:\n            # the ignore label must also be converted\n            transforms.append(\n                ConvertSegmentationToRegionsTransform(\n                    
regions=list(regions) + [ignore_label] if ignore_label is not None else regions,\n                    channel_in_seg=0\n                )\n            )\n\n        if deep_supervision_scales is not None:\n            transforms.append(DownsampleSegForDSTransform(ds_scales=deep_supervision_scales))\n        return ComposeTransforms(transforms)\n\n    def set_deep_supervision_enabled(self, enabled: bool):\n        \"\"\"\n        This function is specific to the default architecture in nnU-Net. If you change the architecture, chances\n        are you need to change this as well!\n        \"\"\"\n        if self.is_ddp:\n            mod = self.network.module\n        else:\n            mod = self.network\n        if isinstance(mod, OptimizedModule):\n            mod = mod._orig_mod\n\n        mod.decoder.deep_supervision = enabled\n\n    def on_train_start(self):\n        if not self.was_initialized:\n            self.initialize()\n\n        # dataloaders must be instantiated here (instead of __init__) because they need access to the training data\n        # which may not be present when doing inference\n        self.dataloader_train, self.dataloader_val = self.get_dataloaders()\n\n        maybe_mkdir_p(self.output_folder)\n\n        # make sure deep supervision is on in the network\n        self.set_deep_supervision_enabled(self.enable_deep_supervision)\n\n        self.print_plans()\n        empty_cache(self.device)\n\n        # maybe unpack\n        if self.local_rank == 0:\n            self.dataset_class.unpack_dataset(\n                self.preprocessed_dataset_folder,\n                overwrite_existing=False,\n                num_processes=max(1, round(get_allowed_n_proc_DA() // 2)),\n                verify=True)\n\n        if self.is_ddp:\n            dist.barrier()\n\n        # copy plans and dataset.json so that they can be used for restoring everything we need for inference\n        save_json(self.plans_manager.plans, join(self.output_folder_base, 
'plans.json'), sort_keys=False)\n        save_json(self.dataset_json, join(self.output_folder_base, 'dataset.json'), sort_keys=False)\n\n        # we don't really need the fingerprint but it's still handy to have it with the others\n        shutil.copy(join(self.preprocessed_dataset_folder_base, 'dataset_fingerprint.json'),\n                    join(self.output_folder_base, 'dataset_fingerprint.json'))\n\n        # produces a pdf in output folder\n        self.plot_network_architecture()\n\n        self._save_debug_information()\n\n        # print(f\"batch size: {self.batch_size}\")\n        # print(f\"oversample: {self.oversample_foreground_percent}\")\n\n    def on_train_end(self):\n        # dirty hack because on_epoch_end increments the epoch counter and this is executed afterwards.\n        # This will lead to the wrong current epoch to be stored\n        self.current_epoch -= 1\n        self.save_checkpoint(join(self.output_folder, \"checkpoint_final.pth\"))\n        self.current_epoch += 1\n\n        # now we can delete latest\n        if self.local_rank == 0 and isfile(join(self.output_folder, \"checkpoint_latest.pth\")):\n            os.remove(join(self.output_folder, \"checkpoint_latest.pth\"))\n\n        # shut down dataloaders\n        old_stdout = sys.stdout\n        with open(os.devnull, 'w') as f:\n            sys.stdout = f\n            if self.dataloader_train is not None and \\\n                    isinstance(self.dataloader_train, (NonDetMultiThreadedAugmenter, MultiThreadedAugmenter)):\n                self.dataloader_train._finish()\n            if self.dataloader_val is not None and \\\n                    isinstance(self.dataloader_val, (NonDetMultiThreadedAugmenter, MultiThreadedAugmenter)):\n                self.dataloader_val._finish()\n            sys.stdout = old_stdout\n\n        empty_cache(self.device)\n        self.print_to_log_file(\"Training done.\")\n\n    def on_train_epoch_start(self):\n        self.network.train()\n        
self.lr_scheduler.step(self.current_epoch)\n        self.print_to_log_file('')\n        self.print_to_log_file(f'Epoch {self.current_epoch}')\n        self.print_to_log_file(\n            f\"Current learning rate: {np.round(self.optimizer.param_groups[0]['lr'], decimals=5)}\")\n        # lrs are the same for all workers so we don't need to gather them in case of DDP training\n        self.logger.log('lrs', self.optimizer.param_groups[0]['lr'], self.current_epoch)\n\n    def train_step(self, batch: dict) -> dict:\n        data = batch['data']\n        target = batch['target']\n\n        data = data.to(self.device, non_blocking=True)\n        if isinstance(target, list):\n            target = [i.to(self.device, non_blocking=True) for i in target]\n        else:\n            target = target.to(self.device, non_blocking=True)\n\n        self.optimizer.zero_grad(set_to_none=True)\n        # Autocast can be annoying\n        # If the device_type is 'cpu' then it's slow as heck and needs to be disabled.\n        # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. 
(this is why we don't make use of enabled=False)\n        # So autocast will only be active if we have a cuda device.\n        with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():\n            output = self.network(data)\n            # del data\n            l = self.loss(output, target)\n\n        if self.grad_scaler is not None:\n            self.grad_scaler.scale(l).backward()\n            self.grad_scaler.unscale_(self.optimizer)\n            torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)\n            self.grad_scaler.step(self.optimizer)\n            self.grad_scaler.update()\n        else:\n            l.backward()\n            torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)\n            self.optimizer.step()\n        return {'loss': l.detach().cpu().numpy()}\n\n    def on_train_epoch_end(self, train_outputs: List[dict]):\n        outputs = collate_outputs(train_outputs)\n\n        if self.is_ddp:\n            losses_tr = [None for _ in range(dist.get_world_size())]\n            dist.all_gather_object(losses_tr, outputs['loss'])\n            loss_here = np.vstack(losses_tr).mean()\n        else:\n            loss_here = np.mean(outputs['loss'])\n\n        self.logger.log('train_losses', loss_here, self.current_epoch)\n\n    def on_validation_epoch_start(self):\n        self.network.eval()\n\n    def validation_step(self, batch: dict) -> dict:\n        data = batch['data']\n        target = batch['target']\n\n        data = data.to(self.device, non_blocking=True)\n        if isinstance(target, list):\n            target = [i.to(self.device, non_blocking=True) for i in target]\n        else:\n            target = target.to(self.device, non_blocking=True)\n\n        # Autocast can be annoying\n        # If the device_type is 'cpu' then it's slow as heck and needs to be disabled.\n        # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False 
is set. Whyyyyyyy. (this is why we don't make use of enabled=False)\n        # So autocast will only be active if we have a cuda device.\n        with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():\n            output = self.network(data)\n            del data\n            l = self.loss(output, target)\n\n        # we only need the output with the highest output resolution (if DS enabled)\n        if self.enable_deep_supervision:\n            output = output[0]\n            target = target[0]\n\n        # the following is needed for online evaluation. Fake dice (green line)\n        axes = [0] + list(range(2, output.ndim))\n\n        if self.label_manager.has_regions:\n            predicted_segmentation_onehot = (torch.sigmoid(output) > 0.5).long()\n        else:\n            # no need for softmax\n            output_seg = output.argmax(1)[:, None]\n            predicted_segmentation_onehot = torch.zeros(output.shape, device=output.device, dtype=torch.float16)\n            predicted_segmentation_onehot.scatter_(1, output_seg, 1)\n            del output_seg\n\n        if self.label_manager.has_ignore_label:\n            if not self.label_manager.has_regions:\n                mask = (target != self.label_manager.ignore_label).float()\n                # CAREFUL that you don't rely on target after this line!\n                target[target == self.label_manager.ignore_label] = 0\n            else:\n                if target.dtype == torch.bool:\n                    mask = ~target[:, -1:]\n                else:\n                    mask = 1 - target[:, -1:]\n                # CAREFUL that you don't rely on target after this line!\n                target = target[:, :-1]\n        else:\n            mask = None\n\n        tp, fp, fn, _ = get_tp_fp_fn_tn(predicted_segmentation_onehot, target, axes=axes, mask=mask)\n\n        tp_hard = tp.detach().cpu().numpy()\n        fp_hard = fp.detach().cpu().numpy()\n        fn_hard = 
fn.detach().cpu().numpy()\n        if not self.label_manager.has_regions:\n            # if we train with regions all segmentation heads predict some kind of foreground. In conventional\n            # (softmax training) there needs to be one output for the background. We are not interested in the\n            # background Dice\n            # [1:] in order to remove background\n            tp_hard = tp_hard[1:]\n            fp_hard = fp_hard[1:]\n            fn_hard = fn_hard[1:]\n\n        return {'loss': l.detach().cpu().numpy(), 'tp_hard': tp_hard, 'fp_hard': fp_hard, 'fn_hard': fn_hard}\n\n    def on_validation_epoch_end(self, val_outputs: List[dict]):\n        outputs_collated = collate_outputs(val_outputs)\n        tp = np.sum(outputs_collated['tp_hard'], 0)\n        fp = np.sum(outputs_collated['fp_hard'], 0)\n        fn = np.sum(outputs_collated['fn_hard'], 0)\n\n        if self.is_ddp:\n            world_size = dist.get_world_size()\n\n            tps = [None for _ in range(world_size)]\n            dist.all_gather_object(tps, tp)\n            tp = np.vstack([i[None] for i in tps]).sum(0)\n\n            fps = [None for _ in range(world_size)]\n            dist.all_gather_object(fps, fp)\n            fp = np.vstack([i[None] for i in fps]).sum(0)\n\n            fns = [None for _ in range(world_size)]\n            dist.all_gather_object(fns, fn)\n            fn = np.vstack([i[None] for i in fns]).sum(0)\n\n            losses_val = [None for _ in range(world_size)]\n            dist.all_gather_object(losses_val, outputs_collated['loss'])\n            loss_here = np.vstack(losses_val).mean()\n        else:\n            loss_here = np.mean(outputs_collated['loss'])\n\n        global_dc_per_class = [i for i in [2 * i / (2 * i + j + k) for i, j, k in zip(tp, fp, fn)]]\n        mean_fg_dice = np.nanmean(global_dc_per_class)\n        self.logger.log('mean_fg_dice', mean_fg_dice, self.current_epoch)\n        self.logger.log('dice_per_class_or_region', 
global_dc_per_class, self.current_epoch)\n        self.logger.log('val_losses', loss_here, self.current_epoch)\n\n    def on_epoch_start(self):\n        self.logger.log('epoch_start_timestamps', time(), self.current_epoch)\n\n    def on_epoch_end(self):\n        self.logger.log('epoch_end_timestamps', time(), self.current_epoch)\n\n        self.print_to_log_file('train_loss', np.round(self.logger.get_value('train_losses', step=-1), decimals=4))\n        self.print_to_log_file('val_loss', np.round(self.logger.get_value('val_losses', step=-1), decimals=4))\n        self.print_to_log_file('Pseudo dice', [np.round(i, decimals=4) for i in\n                                               self.logger.get_value('dice_per_class_or_region', step=-1)])\n        self.print_to_log_file(\n            f\"Epoch time: {np.round(self.logger.get_value('epoch_end_timestamps', step=-1) - self.logger.get_value('epoch_start_timestamps', step=-1), decimals=2)} s\")\n\n        # handling periodic checkpointing\n        current_epoch = self.current_epoch\n        if (current_epoch + 1) % self.save_every == 0 and current_epoch != (self.num_epochs - 1):\n            self.save_checkpoint(join(self.output_folder, 'checkpoint_latest.pth'))\n\n        # handle 'best' checkpointing. ema_fg_dice is computed by the logger and can be accessed like this\n        if self._best_ema is None or self.logger.get_value('ema_fg_dice', step=-1) > self._best_ema:\n            self._best_ema = self.logger.get_value('ema_fg_dice', step=-1)\n            self.print_to_log_file(f\"Yayy! 
New best EMA pseudo Dice: {np.round(self._best_ema, decimals=4)}\")\n            self.save_checkpoint(join(self.output_folder, 'checkpoint_best.pth'))\n\n        if self.local_rank == 0:\n            self.logger.plot_progress_png(self.output_folder)\n\n        self.current_epoch += 1\n\n    def save_checkpoint(self, filename: str) -> None:\n        if self.local_rank == 0:\n            if not self.disable_checkpointing:\n                if self.is_ddp:\n                    mod = self.network.module\n                else:\n                    mod = self.network\n                if isinstance(mod, OptimizedModule):\n                    mod = mod._orig_mod\n\n                checkpoint = {\n                    'network_weights': mod.state_dict(),\n                    'optimizer_state': self.optimizer.state_dict(),\n                    'grad_scaler_state': self.grad_scaler.state_dict() if self.grad_scaler is not None else None,\n                    'logging': self.logger.get_checkpoint(),\n                    '_best_ema': self._best_ema,\n                    'current_epoch': self.current_epoch + 1,\n                    'init_args': self.my_init_kwargs,\n                    'trainer_name': self.__class__.__name__,\n                    'inference_allowed_mirroring_axes': self.inference_allowed_mirroring_axes,\n                }\n                torch.save(checkpoint, filename)\n            else:\n                self.print_to_log_file('No checkpoint written, checkpointing is disabled')\n\n    def load_checkpoint(self, filename_or_checkpoint: Union[dict, str]) -> None:\n        if not self.was_initialized:\n            self.initialize()\n\n        if isinstance(filename_or_checkpoint, str):\n            checkpoint = torch.load(filename_or_checkpoint, map_location=self.device, weights_only=False)\n        else:\n            # an already loaded checkpoint dict was passed in directly\n            checkpoint = filename_or_checkpoint\n        # if state dict comes from nn.DataParallel but we use non-parallel model here then the state dict keys do not\n        # match. 
Use heuristic to make it match\n        new_state_dict = {}\n        for k, value in checkpoint['network_weights'].items():\n            key = k\n            if key not in self.network.state_dict().keys() and key.startswith('module.'):\n                key = key[7:]\n            new_state_dict[key] = value\n\n        self.my_init_kwargs = checkpoint['init_args']\n        self.current_epoch = checkpoint['current_epoch']\n        self.logger.load_checkpoint(checkpoint['logging'])\n        self._best_ema = checkpoint['_best_ema']\n        self.inference_allowed_mirroring_axes = checkpoint[\n            'inference_allowed_mirroring_axes'] if 'inference_allowed_mirroring_axes' in checkpoint.keys() else self.inference_allowed_mirroring_axes\n\n        # messing with state dict naming schemes. Facepalm.\n        if self.is_ddp:\n            if isinstance(self.network.module, OptimizedModule):\n                self.network.module._orig_mod.load_state_dict(new_state_dict)\n            else:\n                self.network.module.load_state_dict(new_state_dict)\n        else:\n            if isinstance(self.network, OptimizedModule):\n                self.network._orig_mod.load_state_dict(new_state_dict)\n            else:\n                self.network.load_state_dict(new_state_dict)\n        self.optimizer.load_state_dict(checkpoint['optimizer_state'])\n        if self.grad_scaler is not None:\n            if checkpoint['grad_scaler_state'] is not None:\n                self.grad_scaler.load_state_dict(checkpoint['grad_scaler_state'])\n\n    def perform_actual_validation(self, save_probabilities: bool = False):\n        self.set_deep_supervision_enabled(False)\n        self.network.eval()\n\n        if self.is_ddp and self.batch_size == 1 and self.enable_deep_supervision and self._do_i_compile():\n            self.print_to_log_file(\"WARNING! batch size is 1 during training and torch.compile is enabled. 
If you \"\n                                   \"encounter crashes in validation then this is because torch.compile forgets \"\n                                   \"to trigger a recompilation of the model with deep supervision disabled. \"\n                                   \"This causes torch.flip to complain about getting a tuple as input. Just rerun the \"\n                                   \"validation with --val (exactly the same as before) and then it will work. \"\n                                   \"Why? Because --val triggers nnU-Net to ONLY run validation meaning that the first \"\n                                   \"forward pass (where compile is triggered) already has deep supervision disabled. \"\n                                   \"This is exactly what we need in perform_actual_validation\")\n\n        predictor = nnUNetPredictor(tile_step_size=0.5, use_gaussian=True, use_mirroring=True,\n                                    perform_everything_on_device=True, device=self.device, verbose=False,\n                                    verbose_preprocessing=False, allow_tqdm=False)\n        predictor.manual_initialization(self.network, self.plans_manager, self.configuration_manager, None,\n                                        self.dataset_json, self.__class__.__name__,\n                                        self.inference_allowed_mirroring_axes)\n\n        with multiprocessing.get_context(\"spawn\").Pool(default_num_processes) as segmentation_export_pool:\n            worker_list = [i for i in segmentation_export_pool._pool]\n            validation_output_folder = join(self.output_folder, 'validation')\n            maybe_mkdir_p(validation_output_folder)\n\n            # we cannot use self.get_tr_and_val_datasets() here because we might be DDP and then we have to distribute\n            # the validation keys across the workers.\n            _, val_keys = self.do_split()\n            if self.is_ddp:\n                last_barrier_at_idx = 
len(val_keys) // dist.get_world_size() - 1\n\n                val_keys = val_keys[self.local_rank:: dist.get_world_size()]\n                # we cannot just have barriers all over the place because the number of keys each GPU receives can be\n                # different\n\n            dataset_val = self.dataset_class(self.preprocessed_dataset_folder, val_keys,\n                                             folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage)\n\n            next_stages = self.configuration_manager.next_stage_names\n\n            if next_stages is not None:\n                _ = [maybe_mkdir_p(join(self.output_folder_base, 'predicted_next_stage', n)) for n in next_stages]\n\n            results = []\n\n            for i, k in enumerate(dataset_val.identifiers):\n                proceed = not check_workers_alive_and_busy(segmentation_export_pool, worker_list, results,\n                                                           allowed_num_queued=2)\n                while not proceed:\n                    sleep(0.1)\n                    proceed = not check_workers_alive_and_busy(segmentation_export_pool, worker_list, results,\n                                                               allowed_num_queued=2)\n\n                self.print_to_log_file(f\"predicting {k}\")\n                data, _, seg_prev, properties = dataset_val.load_case(k)\n\n                # we do [:] to convert blosc2 to numpy\n                data = data[:]\n\n                if self.is_cascaded:\n                    seg_prev = seg_prev[:]\n                    data = np.vstack((data, convert_labelmap_to_one_hot(seg_prev, self.label_manager.foreground_labels,\n                                                                        output_dtype=data.dtype)))\n                with warnings.catch_warnings():\n                    # ignore 'The given NumPy array is not writable' warning\n                    warnings.simplefilter(\"ignore\")\n                 
   data = torch.from_numpy(data)\n\n                self.print_to_log_file(f'{k}, shape {data.shape}, rank {self.local_rank}')\n                output_filename_truncated = join(validation_output_folder, k)\n\n                prediction = predictor.predict_sliding_window_return_logits(data)\n                prediction = prediction.cpu()\n\n                # this needs to go into background processes\n                results.append(\n                    segmentation_export_pool.starmap_async(\n                        export_prediction_from_logits, (\n                            (prediction, properties, self.configuration_manager, self.plans_manager,\n                             self.dataset_json, output_filename_truncated, save_probabilities),\n                        )\n                    )\n                )\n                # for debug purposes\n                # export_prediction_from_logits(\n                #     prediction, properties, self.configuration_manager, self.plans_manager,\n                #      self.dataset_json, output_filename_truncated, save_probabilities\n                # )\n\n                # if needed, export the softmax prediction for the next stage\n                if next_stages is not None:\n                    for n in next_stages:\n                        next_stage_config_manager = self.plans_manager.get_configuration(n)\n                        expected_preprocessed_folder = join(nnUNet_preprocessed, self.plans_manager.dataset_name,\n                                                            next_stage_config_manager.data_identifier)\n                        # next stage may have a different dataset class, do not use self.dataset_class\n                        dataset_class = infer_dataset_class(expected_preprocessed_folder)\n\n                        try:\n                            # we do this so that we can use load_case and do not have to hard code how loading training cases is implemented\n                            tmp = 
dataset_class(expected_preprocessed_folder, [k])\n                            d, _, _, _ = tmp.load_case(k)\n                        except FileNotFoundError:\n                            self.print_to_log_file(\n                                f\"Predicting next stage {n} failed for case {k} because the preprocessed file is missing! \"\n                                f\"Run the preprocessing for this configuration first!\")\n                            continue\n\n                        target_shape = d.shape[1:]\n                        output_folder = join(self.output_folder_base, 'predicted_next_stage', n)\n                        output_file_truncated = join(output_folder, k)\n\n                        # resample_and_save(prediction, target_shape, output_file_truncated, self.plans_manager,\n                        #          self.configuration_manager,\n                        #          properties,\n                        #          self.dataset_json,\n                        #          default_num_processes,\n                        #          dataset_class)\n                        results.append(segmentation_export_pool.starmap_async(\n                            resample_and_save, (\n                                (prediction, target_shape, output_file_truncated, self.plans_manager,\n                                 self.configuration_manager,\n                                 properties,\n                                 self.dataset_json,\n                                 default_num_processes,\n                                 dataset_class),\n                            )\n                        ))\n                # if we don't barrier from time to time we will get nccl timeouts for large datasets. 
Yuck.\n                if self.is_ddp and i < last_barrier_at_idx and (i + 1) % 20 == 0:\n                    dist.barrier()\n\n            _ = [r.get() for r in results]\n\n        if self.is_ddp:\n            dist.barrier()\n\n        if self.local_rank == 0:\n            metrics = compute_metrics_on_folder(join(self.preprocessed_dataset_folder_base, 'gt_segmentations'),\n                                                validation_output_folder,\n                                                join(validation_output_folder, 'summary.json'),\n                                                self.plans_manager.image_reader_writer_class(),\n                                                self.dataset_json[\"file_ending\"],\n                                                self.label_manager.foreground_regions if self.label_manager.has_regions else\n                                                self.label_manager.foreground_labels,\n                                                self.label_manager.ignore_label, chill=True,\n                                                num_processes=default_num_processes * dist.get_world_size() if\n                                                self.is_ddp else default_num_processes)\n            for label in metrics[\"mean\"]:\n                self.logger.log_summary(f\"final_val/class_{label}_dice\", metrics[\"mean\"][label][\"Dice\"])\n            self.logger.log_summary(\"final_val/foreground_dice\", metrics['foreground_mean'][\"Dice\"])\n            self.print_to_log_file(\"Validation complete\", also_print_to_console=True)\n            self.print_to_log_file(\"Mean Validation Dice: \", (metrics['foreground_mean'][\"Dice\"]),\n                                   also_print_to_console=True)\n\n        self.set_deep_supervision_enabled(True)\n        compute_gaussian.cache_clear()\n\n    def run_training(self):\n        self.on_train_start()\n\n        for epoch in range(self.current_epoch, self.num_epochs):\n            
self.on_epoch_start()\n\n            self.on_train_epoch_start()\n            train_outputs = []\n            for batch_id in range(self.num_iterations_per_epoch):\n                train_outputs.append(self.train_step(next(self.dataloader_train)))\n            self.on_train_epoch_end(train_outputs)\n\n            with torch.no_grad():\n                self.on_validation_epoch_start()\n                val_outputs = []\n                for batch_id in range(self.num_val_iterations_per_epoch):\n                    val_outputs.append(self.validation_step(next(self.dataloader_val)))\n                self.on_validation_epoch_end(val_outputs)\n\n            self.on_epoch_end()\n\n        self.on_train_end()\n"
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/primus/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/primus/primus_trainers.py",
    "content": "from abc import abstractmethod\nfrom typing import List, Tuple, Union\nimport torch\nfrom torch import nn, autocast\nfrom dynamic_network_architectures.architectures.primus import Primus\nfrom nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer\nfrom nnunetv2.training.nnUNetTrainer.variants.lr_schedule.nnUNetTrainer_warmup import nnUNetTrainer_warmup\nfrom torch.nn.parallel import DistributedDataParallel as DDP\nfrom nnunetv2.training.lr_scheduler.warmup import Lin_incr_LRScheduler, PolyLRScheduler_offset\nfrom nnunetv2.utilities.helpers import empty_cache, dummy_context\n\n######################################################\n# See this paper for information on Primus!\n# Wald*, T., Roy*, S., Isensee*, F., Ulrich, C., Ziegler, S., Trofimova, D., ... & Maier-Hein, K. (2025). Primus: Enforcing attention usage for 3d medical image segmentation. arXiv preprint arXiv:2503.01835.\n# * equal contribution\n######################################################\n\nclass AbstractPrimus(nnUNetTrainer_warmup):\n    def __init__(\n        self,\n        plans: dict,\n        configuration: str,\n        fold: int,\n        dataset_json: dict,\n        device: torch.device = torch.device(\"cuda\"),\n    ):\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.initial_lr = 3e-4\n        self.weight_decay = 5e-2\n        self.enable_deep_supervision = False\n\n    @abstractmethod\n    def build_network_architecture(\n        self,\n        architecture_class_name: str,\n        arch_init_kwargs: dict,\n        arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],\n        num_input_channels: int,\n        num_output_channels: int,\n        enable_deep_supervision: bool = True,\n    ) -> nn.Module:\n        raise NotImplementedError()\n\n    def configure_optimizers(self, stage: str = \"warmup_all\"):\n        assert stage in [\"warmup_all\", \"train\"]\n\n        if self.training_stage == stage:\n      
      return self.optimizer, self.lr_scheduler\n\n        if isinstance(self.network, DDP):\n            params = self.network.module.parameters()\n        else:\n            params = self.network.parameters()\n\n        if stage == \"warmup_all\":\n            self.print_to_log_file(\"train whole net, warmup\")\n            optimizer = torch.optim.AdamW(\n                params, self.initial_lr, weight_decay=self.weight_decay, amsgrad=False, betas=(0.9, 0.98), fused=True\n            )\n            lr_scheduler = Lin_incr_LRScheduler(optimizer, self.initial_lr, self.warmup_duration_whole_net)\n            self.print_to_log_file(f\"Initialized warmup_all optimizer and lr_scheduler at epoch {self.current_epoch}\")\n        else:\n            self.print_to_log_file(\"train whole net, default schedule\")\n            if self.training_stage == \"warmup_all\":\n                # we can keep the existing optimizer and don't need to create a new one. This will allow us to keep\n                # the accumulated momentum terms which already point in a useful direction\n                optimizer = self.optimizer\n            else:\n                optimizer = torch.optim.AdamW(\n                    params,\n                    self.initial_lr,\n                    weight_decay=self.weight_decay,\n                    amsgrad=False,\n                    betas=(0.9, 0.98),\n                    fused=True,\n                )\n            lr_scheduler = PolyLRScheduler_offset(\n                optimizer, self.initial_lr, self.num_epochs, self.warmup_duration_whole_net\n            )\n            self.print_to_log_file(f\"Initialized train optimizer and lr_scheduler at epoch {self.current_epoch}\")\n        self.training_stage = stage\n        empty_cache(self.device)\n        return optimizer, lr_scheduler\n\n    def train_step(self, batch: dict) -> dict:\n        data = batch[\"data\"]\n        target = batch[\"target\"]\n\n        data = data.to(self.device, 
non_blocking=True)\n        if isinstance(target, list):\n            target = [i.to(self.device, non_blocking=True) for i in target]\n        else:\n            target = target.to(self.device, non_blocking=True)\n\n        self.optimizer.zero_grad(set_to_none=True)\n        # Autocast can be annoying\n        # If the device_type is 'cpu' then it's slow as heck and needs to be disabled.\n        # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False)\n        # So autocast will only be active if we have a cuda device.\n        with autocast(self.device.type, enabled=True) if self.device.type == \"cuda\" else dummy_context():\n            output = self.network(data)\n            # del data\n            l = self.loss(output, target)\n\n        if self.grad_scaler is not None:\n            self.grad_scaler.scale(l).backward()\n            self.grad_scaler.unscale_(self.optimizer)\n            torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1)\n            self.grad_scaler.step(self.optimizer)\n            self.grad_scaler.update()\n        else:\n            l.backward()\n            torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1)\n            self.optimizer.step()\n        return {\"loss\": l.detach().cpu().numpy()}\n\n    def set_deep_supervision_enabled(self, enabled: bool):\n        pass\n\n\nclass nnUNet_Primus_S_Trainer(AbstractPrimus):\n\n    def build_network_architecture(\n        self,\n        architecture_class_name: str,\n        arch_init_kwargs: dict,\n        arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],\n        num_input_channels: int,\n        num_output_channels: int,\n        enable_deep_supervision: bool = True,\n    ) -> nn.Module:\n        # this architecture will crash if the patch size is not divisible by 8!\n        model = Primus(\n            num_input_channels,\n            396,\n     
       (8, 8, 8),\n            num_output_channels,\n            12,\n            6,\n            self.configuration_manager.patch_size,\n            drop_path_rate=0.2,\n            scale_attn_inner=True,\n            init_values=0.1,\n        )\n        return model\n\n\nclass nnUNet_Primus_B_Trainer(AbstractPrimus):\n\n    def build_network_architecture(\n        self,\n        architecture_class_name: str,\n        arch_init_kwargs: dict,\n        arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],\n        num_input_channels: int,\n        num_output_channels: int,\n        enable_deep_supervision: bool = True,\n    ) -> nn.Module:\n        # this architecture will crash if the patch size is not divisible by 8!\n        model = Primus(\n            num_input_channels,\n            792,\n            (8, 8, 8),\n            num_output_channels,\n            12,\n            12,\n            self.configuration_manager.patch_size,\n            drop_path_rate=0.2,\n            scale_attn_inner=True,\n            init_values=0.1,\n        )\n        return model\n\n\nclass nnUNet_Primus_M_Trainer(AbstractPrimus):\n\n    def build_network_architecture(\n        self,\n        architecture_class_name: str,\n        arch_init_kwargs: dict,\n        arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],\n        num_input_channels: int,\n        num_output_channels: int,\n        enable_deep_supervision: bool = True,\n    ) -> nn.Module:\n        # this architecture will crash if the patch size is not divisible by 8!\n        model = Primus(\n            num_input_channels,\n            864,\n            (8, 8, 8),\n            num_output_channels,\n            16,\n            12,\n            self.configuration_manager.patch_size,\n            drop_path_rate=0.2,\n            scale_attn_inner=True,\n            init_values=0.1,\n        )\n        return model\n\n\nclass nnUNet_Primus_M_Trainer_BS8(nnUNet_Primus_M_Trainer):\n\n    def __init__(\n  
      self,\n        plans: dict,\n        configuration: str,\n        fold: int,\n        dataset_json: dict,\n        device: torch.device = torch.device(\"cuda\"),\n    ):\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.configuration_manager.configuration[\"batch_size\"] = 8\n\n\nclass nnUNet_Primus_M_Trainer_BS8_2e4(nnUNet_Primus_M_Trainer):\n\n    def __init__(\n        self,\n        plans: dict,\n        configuration: str,\n        fold: int,\n        dataset_json: dict,\n        device: torch.device = torch.device(\"cuda\"),\n    ):\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.initial_lr = 2e-4\n        self.configuration_manager.configuration[\"batch_size\"] = 8\n\n\nclass nnUNet_Trainer_BS8(nnUNetTrainer):\n\n    def __init__(\n        self,\n        plans: dict,\n        configuration: str,\n        fold: int,\n        dataset_json: dict,\n        device: torch.device = torch.device(\"cuda\"),\n    ):\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.configuration_manager.configuration[\"batch_size\"] = 8\n\n\nclass nnUNet_Primus_L_Trainer(AbstractPrimus):\n\n    def build_network_architecture(\n        self,\n        architecture_class_name: str,\n        arch_init_kwargs: dict,\n        arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],\n        num_input_channels: int,\n        num_output_channels: int,\n        enable_deep_supervision: bool = True,\n    ) -> nn.Module:\n        # this architecture will crash if the patch size is not divisible by 8!\n        model = Primus(\n            num_input_channels,\n            1056,\n            (8, 8, 8),\n            num_output_channels,\n            24,\n            16,\n            self.configuration_manager.patch_size,\n            drop_path_rate=0.2,\n            scale_attn_inner=True,\n            init_values=0.1,\n        )\n        return model\n\n\nclass 
_Primus_S_96_BS1(nnUNet_Primus_S_Trainer):\n    def __init__(\n        self,\n        plans: dict,\n        configuration: str,\n        fold: int,\n        dataset_json: dict,\n        device: torch.device = torch.device(\"cuda\"),\n    ):\n        plans[\"configurations\"][configuration][\"patch_size\"] = (96, 96, 96)  # As per repository\n        plans[\"configurations\"][configuration][\"batch_size\"] = 1\n        super().__init__(plans, configuration, fold, dataset_json, device)\n\n\nclass _Primus_B_96_BS1(nnUNet_Primus_B_Trainer):\n    def __init__(\n        self,\n        plans: dict,\n        configuration: str,\n        fold: int,\n        dataset_json: dict,\n        device: torch.device = torch.device(\"cuda\"),\n    ):\n        plans[\"configurations\"][configuration][\"patch_size\"] = (96, 96, 96)  # As per repository\n        plans[\"configurations\"][configuration][\"batch_size\"] = 1\n        super().__init__(plans, configuration, fold, dataset_json, device)\n\n\nclass _Primus_M_96_BS1(nnUNet_Primus_M_Trainer):\n    def __init__(\n        self,\n        plans: dict,\n        configuration: str,\n        fold: int,\n        dataset_json: dict,\n        device: torch.device = torch.device(\"cuda\"),\n    ):\n        plans[\"configurations\"][configuration][\"patch_size\"] = (96, 96, 96)  # As per repository\n        plans[\"configurations\"][configuration][\"batch_size\"] = 1\n        super().__init__(plans, configuration, fold, dataset_json, device)\n\n\nclass _Primus_L_48_BS1(nnUNet_Primus_L_Trainer):\n    def __init__(\n        self,\n        plans: dict,\n        configuration: str,\n        fold: int,\n        dataset_json: dict,\n        device: torch.device = torch.device(\"cuda\"),\n    ):\n        plans[\"configurations\"][configuration][\"patch_size\"] = (48, 48, 48)  # As per repository\n        plans[\"configurations\"][configuration][\"batch_size\"] = 1\n        super().__init__(plans, configuration, fold, dataset_json, device)"
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/variants/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/variants/benchmarking/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs.py",
    "content": "import subprocess\n\nimport torch\nfrom batchgenerators.utilities.file_and_folder_operations import save_json, join, isfile, load_json\n\nfrom nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer\nfrom torch import distributed as dist\n\n\nclass nnUNetTrainerBenchmark_5epochs(nnUNetTrainer):\n    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, \n                 device: torch.device = torch.device('cuda')):\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        assert self.fold == 0, \"It makes absolutely no sense to specify a certain fold. Stick with 0 so that we can parse the results.\"\n        self.disable_checkpointing = True\n        self.num_epochs = 5\n        assert torch.cuda.is_available(), \"This only works on GPU\"\n        self.crashed_with_runtime_error = False\n\n    def perform_actual_validation(self, save_probabilities: bool = False):\n        pass\n\n    def save_checkpoint(self, filename: str) -> None:\n        # do not trust people to remember that self.disable_checkpointing must be True for this trainer\n        pass\n\n    def run_training(self):\n        try:\n            super().run_training()\n        except RuntimeError:\n            self.crashed_with_runtime_error = True\n            self.on_train_end()\n\n    def on_train_end(self):\n        super().on_train_end()\n\n        if not self.is_ddp or self.local_rank == 0:\n            torch_version = torch.__version__\n            cudnn_version = torch.backends.cudnn.version()\n            gpu_name = torch.cuda.get_device_name()\n            if self.crashed_with_runtime_error:\n                fastest_epoch = 'Not enough VRAM!'\n            else:\n                epoch_times = [i - j for i, j in zip(self.logger.get_value('epoch_end_timestamps', step=None),\n                                                     self.logger.get_value('epoch_start_timestamps', step=None))]\n                
fastest_epoch = min(epoch_times)\n\n            if self.is_ddp:\n                num_gpus = dist.get_world_size()\n            else:\n                num_gpus = 1\n\n            benchmark_result_file = join(self.output_folder, 'benchmark_result.json')\n            if isfile(benchmark_result_file):\n                old_results = load_json(benchmark_result_file)\n            else:\n                old_results = {}\n            # generate some unique key\n            hostname = subprocess.getoutput('hostname')\n            my_key = f\"{hostname}__{cudnn_version}__{torch_version.replace(' ', '')}__{gpu_name.replace(' ', '')}__num_gpus_{num_gpus}\"\n            old_results[my_key] = {\n                'torch_version': torch_version,\n                'cudnn_version': cudnn_version,\n                'gpu_name': gpu_name,\n                'fastest_epoch': fastest_epoch,\n                'num_gpus': num_gpus,\n                'hostname': hostname\n            }\n            save_json(old_results,\n                      join(self.output_folder, 'benchmark_result.json'))\n"
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs_noDataLoading.py",
    "content": "import torch\n\nfrom nnunetv2.training.nnUNetTrainer.variants.benchmarking.nnUNetTrainerBenchmark_5epochs import (\n    nnUNetTrainerBenchmark_5epochs,\n)\nfrom nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels\n\n\nclass nnUNetTrainerBenchmark_5epochs_noDataLoading(nnUNetTrainerBenchmark_5epochs):\n    def __init__(\n        self,\n        plans: dict,\n        configuration: str,\n        fold: int,\n        dataset_json: dict,\n        device: torch.device = torch.device(\"cuda\"),\n    ):\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self._set_batch_size_and_oversample()\n        num_input_channels = determine_num_input_channels(\n            self.plans_manager, self.configuration_manager, self.dataset_json\n        )\n        patch_size = self.configuration_manager.patch_size\n        dummy_data = torch.rand((self.batch_size, num_input_channels, *patch_size), device=self.device)\n        if self.enable_deep_supervision:\n            dummy_target = [\n                torch.round(\n                    torch.rand((self.batch_size, 1, *[int(i * j) for i, j in zip(patch_size, k)]), device=self.device)\n                    * max(self.label_manager.all_labels)\n                )\n                for k in self._get_deep_supervision_scales()\n            ]\n        else:\n            raise NotImplementedError(\"This trainer does not support deep supervision\")\n        self.dummy_batch = {\"data\": dummy_data, \"target\": dummy_target}\n\n    def get_dataloaders(self):\n        return None, None\n\n    def run_training(self):\n        try:\n            self.on_train_start()\n\n            for epoch in range(self.current_epoch, self.num_epochs):\n                self.on_epoch_start()\n\n                self.on_train_epoch_start()\n                train_outputs = []\n                for batch_id in range(self.num_iterations_per_epoch):\n                    
train_outputs.append(self.train_step(self.dummy_batch))\n                self.on_train_epoch_end(train_outputs)\n\n                with torch.no_grad():\n                    self.on_validation_epoch_start()\n                    val_outputs = []\n                    for batch_id in range(self.num_val_iterations_per_epoch):\n                        val_outputs.append(self.validation_step(self.dummy_batch))\n                    self.on_validation_epoch_end(val_outputs)\n\n                self.on_epoch_end()\n\n            self.on_train_end()\n        except RuntimeError:\n            self.crashed_with_runtime_error = True\n"
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/variants/competitions/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/variants/competitions/aortaseg24.py",
    "content": "from nnunetv2.training.nnUNetTrainer.variants.data_augmentation.nnUNetTrainerNoMirroring import nnUNetTrainer_onlyMirror01\nfrom nnunetv2.training.nnUNetTrainer.variants.data_augmentation.nnUNetTrainerDA5 import nnUNetTrainerDA5\n\nclass nnUNetTrainer_onlyMirror01_DA5(nnUNetTrainer_onlyMirror01, nnUNetTrainerDA5):\n    pass\n"
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/variants/data_augmentation/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDA5.py",
    "content": "from typing import List, Union, Tuple\n\nimport numpy as np\nimport torch\nfrom batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter\nfrom batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter\nfrom batchgenerators.transforms.abstract_transforms import AbstractTransform\nfrom batchgenerators.transforms.abstract_transforms import Compose\nfrom batchgenerators.transforms.color_transforms import BrightnessTransform, ContrastAugmentationTransform, \\\n    GammaTransform\nfrom batchgenerators.transforms.local_transforms import BrightnessGradientAdditiveTransform, LocalGammaTransform\nfrom batchgenerators.transforms.noise_transforms import MedianFilterTransform, GaussianBlurTransform, \\\n    GaussianNoiseTransform, BlankRectangleTransform, SharpeningTransform\nfrom batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform\nfrom batchgenerators.transforms.spatial_transforms import SpatialTransform, Rot90Transform, TransposeAxesTransform, \\\n    MirrorTransform\nfrom batchgenerators.transforms.utility_transforms import OneOfTransform, RemoveLabelTransform, RenameTransform, \\\n    NumpyToTensor\nfrom batchgeneratorsv2.helpers.scalar_type import RandomScalar\nfrom torch import autocast\n\nfrom nnunetv2.configuration import ANISO_THRESHOLD\nfrom nnunetv2.training.data_augmentation.compute_initial_patch_size import get_patch_size\nfrom nnunetv2.training.data_augmentation.custom_transforms.cascade_transforms import MoveSegAsOneHotToData, \\\n    ApplyRandomBinaryOperatorTransform, RemoveRandomConnectedComponentFromOneHotEncodingTransform\nfrom nnunetv2.training.data_augmentation.custom_transforms.deep_supervision_donwsampling import \\\n    DownsampleSegForDSTransform2\nfrom nnunetv2.training.data_augmentation.custom_transforms.masking import MaskTransform\nfrom nnunetv2.training.data_augmentation.custom_transforms.region_based_training import \\\n    
ConvertSegmentationToRegionsTransform\nfrom nnunetv2.training.data_augmentation.custom_transforms.transforms_for_dummy_2d import Convert3DTo2DTransform, \\\n    Convert2DTo3DTransform\nfrom nnunetv2.training.dataloading.data_loader import nnUNetDataLoader\nfrom nnunetv2.training.loss.dice import get_tp_fp_fn_tn\nfrom nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer\nfrom nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA\nfrom nnunetv2.utilities.helpers import dummy_context\n\n\nclass TensorToNumpy(AbstractTransform):\n    def __init__(self, keys=None, cast_to=None):\n        \"\"\"\n        Converts torch tensors to numpy ndarrays.\n\n        :param keys: specify keys to be converted. If None then all tensor values will be converted.\n                     Can be a key (string) or a list/tuple of keys.\n        :param cast_to: optional numpy dtype as string, e.g. 'float32', 'float16', 'int64', 'bool'\n        \"\"\"\n        if keys is not None and not isinstance(keys, (list, tuple)):\n            keys = [keys]\n        self.keys = keys\n        self.cast_to = cast_to\n\n    def cast(self, array: np.ndarray):\n        if self.cast_to is not None:\n            try:\n                array = array.astype(self.cast_to, copy=False)\n            except TypeError:\n                raise ValueError(f\"Unknown value for cast_to: {self.cast_to}\")\n        return array\n\n    def _to_numpy(self, tensor):\n        import torch\n\n        if isinstance(tensor, torch.Tensor):\n            # Important: detach + move to CPU before numpy conversion\n            array = tensor.detach().cpu().numpy()\n            return self.cast(array)\n        return tensor\n\n    def __call__(self, **data_dict):\n        import torch\n\n        if self.keys is None:\n            for key, val in data_dict.items():\n                if isinstance(val, torch.Tensor):\n                    data_dict[key] = self._to_numpy(val)\n                elif isinstance(val, (list, 
tuple)) and all(isinstance(i, torch.Tensor) for i in val):\n                    data_dict[key] = [self._to_numpy(i) for i in val]\n        else:\n            for key in self.keys:\n                val = data_dict[key]\n                if isinstance(val, torch.Tensor):\n                    data_dict[key] = self._to_numpy(val)\n                elif isinstance(val, (list, tuple)) and all(isinstance(i, torch.Tensor) for i in val):\n                    data_dict[key] = [self._to_numpy(i) for i in val]\n\n        return data_dict\n\n\nclass nnUNetTrainerDA5(nnUNetTrainer):\n    def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):\n        patch_size = self.configuration_manager.patch_size\n        dim = len(patch_size)\n        # todo rotation should be defined dynamically based on patch size (more isotropic patch sizes = more rotation)\n        if dim == 2:\n            do_dummy_2d_data_aug = False\n            # todo revisit this parametrization\n            if max(patch_size) / min(patch_size) > 1.5:\n                rotation_for_DA = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi)\n            else:\n                rotation_for_DA = (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi)\n            mirror_axes = (0, 1)\n        elif dim == 3:\n            # todo this is not ideal. We could also have patch_size (64, 16, 128) in which case a full 180deg 2d rot would be bad\n            # order of the axes is determined by spacing, not image size\n            do_dummy_2d_data_aug = (max(patch_size) / patch_size[0]) > ANISO_THRESHOLD\n            if do_dummy_2d_data_aug:\n                # why do we rotate 180 deg here all the time? We should also restrict it\n                rotation_for_DA = (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi)\n            else:\n                rotation_for_DA = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. 
* np.pi)\n            mirror_axes = (0, 1, 2)\n        else:\n            raise RuntimeError()\n\n        # todo this function is stupid. It doesn't even use the correct scale range (we keep things as they were in the\n        #  old nnunet for now)\n        initial_patch_size = get_patch_size(patch_size[-dim:],\n                                            rotation_for_DA,\n                                            rotation_for_DA,\n                                            rotation_for_DA,\n                                            (0.7, 1.43))\n        if do_dummy_2d_data_aug:\n            initial_patch_size[0] = patch_size[0]\n\n        self.print_to_log_file(f'do_dummy_2d_data_aug: {do_dummy_2d_data_aug}')\n        self.inference_allowed_mirroring_axes = mirror_axes\n\n        return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes\n\n    @staticmethod\n    def get_training_transforms(\n            patch_size: Union[np.ndarray, Tuple[int]],\n            rotation_for_DA: RandomScalar,\n            deep_supervision_scales: Union[List, Tuple, None],\n            mirror_axes: Tuple[int, ...],\n            do_dummy_2d_data_aug: bool,\n            use_mask_for_norm: List[bool] = None,\n            is_cascaded: bool = False,\n            foreground_labels: Union[Tuple[int, ...], List[int]] = None,\n            regions: List[Union[List[int], Tuple[int, ...], int]] = None,\n            ignore_label: int = None,\n    ) -> AbstractTransform:\n        matching_axes = np.array([sum([i == j for j in patch_size]) for i in patch_size])\n        valid_axes = list(np.where(matching_axes == np.max(matching_axes))[0])\n\n        tr_transforms = []\n        tr_transforms.append(TensorToNumpy())\n        tr_transforms.append(RenameTransform('target', 'seg', True))\n\n        if do_dummy_2d_data_aug:\n            ignore_axes = (0,)\n            tr_transforms.append(Convert3DTo2DTransform())\n            patch_size_spatial = patch_size[1:]\n        else:\n  
          patch_size_spatial = patch_size\n            ignore_axes = None\n\n        tr_transforms.append(\n            SpatialTransform(\n                patch_size_spatial,\n                patch_center_dist_from_border=None,\n                do_elastic_deform=False,\n                do_rotation=True,\n                angle_x=rotation_for_DA,\n                angle_y=rotation_for_DA,\n                angle_z=rotation_for_DA,\n                p_rot_per_axis=0.5,\n                do_scale=True,\n                scale=(0.7, 1.43),\n                border_mode_data=\"constant\",\n                border_cval_data=0,\n                order_data=3,\n                border_mode_seg=\"constant\",\n                border_cval_seg=-1,\n                order_seg=1,\n                random_crop=False,\n                p_el_per_sample=0.2,\n                p_scale_per_sample=0.2,\n                p_rot_per_sample=0.4,\n                independent_scale_for_each_axis=True,\n            )\n        )\n\n        if do_dummy_2d_data_aug:\n            tr_transforms.append(Convert2DTo3DTransform())\n\n        if np.any(matching_axes > 1):\n            tr_transforms.append(\n                Rot90Transform(\n                    (0, 1, 2, 3), axes=valid_axes, data_key='data', label_key='seg', p_per_sample=0.5\n                ),\n            )\n\n        if np.any(matching_axes > 1):\n            tr_transforms.append(\n                TransposeAxesTransform(valid_axes, data_key='data', label_key='seg', p_per_sample=0.5)\n            )\n\n        tr_transforms.append(OneOfTransform([\n            MedianFilterTransform(\n                (2, 8),\n                same_for_each_channel=False,\n                p_per_sample=0.2,\n                p_per_channel=0.5\n            ),\n            GaussianBlurTransform((0.3, 1.5),\n                                  different_sigma_per_channel=True,\n                                  p_per_sample=0.2,\n                                  
p_per_channel=0.5)\n        ]))\n\n        tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))\n\n        tr_transforms.append(BrightnessTransform(0,\n                                                 0.5,\n                                                 per_channel=True,\n                                                 p_per_sample=0.1,\n                                                 p_per_channel=0.5\n                                                 )\n                             )\n\n        tr_transforms.append(OneOfTransform(\n            [\n                ContrastAugmentationTransform(\n                    contrast_range=(0.5, 2),\n                    preserve_range=True,\n                    per_channel=True,\n                    data_key='data',\n                    p_per_sample=0.2,\n                    p_per_channel=0.5\n                ),\n                ContrastAugmentationTransform(\n                    contrast_range=(0.5, 2),\n                    preserve_range=False,\n                    per_channel=True,\n                    data_key='data',\n                    p_per_sample=0.2,\n                    p_per_channel=0.5\n                ),\n            ]\n        ))\n\n        tr_transforms.append(\n            SimulateLowResolutionTransform(zoom_range=(0.25, 1),\n                                           per_channel=True,\n                                           p_per_channel=0.5,\n                                           order_downsample=0,\n                                           order_upsample=3,\n                                           p_per_sample=0.15,\n                                           ignore_axes=ignore_axes\n                                           )\n        )\n\n        tr_transforms.append(\n            GammaTransform((0.7, 1.5), invert_image=True, per_channel=True, retain_stats=True, p_per_sample=0.1))\n        tr_transforms.append(\n            GammaTransform((0.7, 1.5), invert_image=True, 
per_channel=True, retain_stats=True, p_per_sample=0.1))\n\n        if mirror_axes is not None and len(mirror_axes) > 0:\n            tr_transforms.append(MirrorTransform(mirror_axes))\n\n        tr_transforms.append(\n            BlankRectangleTransform([[max(1, p // 10), p // 3] for p in patch_size],\n                                    rectangle_value=np.mean,\n                                    num_rectangles=(1, 5),\n                                    force_square=False,\n                                    p_per_sample=0.4,\n                                    p_per_channel=0.5\n                                    )\n        )\n\n        tr_transforms.append(\n            BrightnessGradientAdditiveTransform(\n                _brightnessadditive_localgamma_transform_scale,\n                (-0.5, 1.5),\n                max_strength=_brightness_gradient_additive_max_strength,\n                mean_centered=False,\n                same_for_all_channels=False,\n                p_per_sample=0.3,\n                p_per_channel=0.5\n            )\n        )\n\n        tr_transforms.append(\n            LocalGammaTransform(\n                _brightnessadditive_localgamma_transform_scale,\n                (-0.5, 1.5),\n                _local_gamma_gamma,\n                same_for_all_channels=False,\n                p_per_sample=0.3,\n                p_per_channel=0.5\n            )\n        )\n\n        tr_transforms.append(\n            SharpeningTransform(\n                strength=(0.1, 1),\n                same_for_each_channel=False,\n                p_per_sample=0.2,\n                p_per_channel=0.5\n            )\n        )\n\n        if use_mask_for_norm is not None and any(use_mask_for_norm):\n            tr_transforms.append(MaskTransform([i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]],\n                                               mask_idx_in_seg=0, set_outside_to=0))\n\n        tr_transforms.append(RemoveLabelTransform(-1, 0))\n\n  
      if is_cascaded:\n            if ignore_label is not None:\n                raise NotImplementedError('ignore label not yet supported in cascade')\n            assert foreground_labels is not None, 'We need all_labels for cascade augmentations'\n            use_labels = [i for i in foreground_labels if i != 0]\n            tr_transforms.append(MoveSegAsOneHotToData(1, use_labels, 'seg', 'data'))\n            tr_transforms.append(ApplyRandomBinaryOperatorTransform(\n                channel_idx=list(range(-len(use_labels), 0)),\n                p_per_sample=0.4,\n                key=\"data\",\n                strel_size=(1, 8),\n                p_per_label=1))\n            tr_transforms.append(\n                RemoveRandomConnectedComponentFromOneHotEncodingTransform(\n                    channel_idx=list(range(-len(use_labels), 0)),\n                    key=\"data\",\n                    p_per_sample=0.2,\n                    fill_with_other_class_p=0,\n                    dont_do_if_covers_more_than_x_percent=0.15))\n\n        tr_transforms.append(RenameTransform('seg', 'target', True))\n\n        if regions is not None:\n            # the ignore label must also be converted\n            tr_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label]\n                                                                       if ignore_label is not None else regions,\n                                                                       'target', 'target'))\n\n        if deep_supervision_scales is not None:\n            tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target',\n                                                              output_key='target'))\n        tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))\n        tr_transforms = Compose(tr_transforms)\n        return tr_transforms\n\n    @staticmethod\n    def get_validation_transforms(\n            
deep_supervision_scales: Union[List, Tuple, None],\n            is_cascaded: bool = False,\n            foreground_labels: Union[Tuple[int, ...], List[int]] = None,\n            regions: List[Union[List[int], Tuple[int, ...], int]] = None,\n            ignore_label: int = None,\n    ) -> AbstractTransform:\n        val_transforms = []\n        val_transforms.append(TensorToNumpy())\n\n        val_transforms.append(RenameTransform('target', 'seg', True))\n        val_transforms.append(RemoveLabelTransform(-1, 0))\n\n        if is_cascaded:\n            val_transforms.append(MoveSegAsOneHotToData(1, foreground_labels, 'seg', 'data'))\n\n        val_transforms.append(RenameTransform('seg', 'target', True))\n\n        if regions is not None:\n            # the ignore label must also be converted\n            val_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label]\n                                                                        if ignore_label is not None else regions,\n                                                                        'target', 'target'))\n\n        if deep_supervision_scales is not None:\n            val_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target',\n                                                               output_key='target'))\n\n        val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))\n        val_transforms = Compose(val_transforms)\n        return val_transforms\n\n    def get_dataloaders(self):\n        # we use the patch size to determine whether we need 2D or 3D dataloaders. 
We also use it to determine whether\n        # we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be\n        patch_size = self.configuration_manager.patch_size\n\n        # needed for deep supervision: how much do we need to downscale the segmentation targets for the different\n        # outputs?\n        deep_supervision_scales = self._get_deep_supervision_scales()\n\n        (\n            rotation_for_DA,\n            do_dummy_2d_data_aug,\n            initial_patch_size,\n            mirror_axes,\n        ) = self.configure_rotation_dummyDA_mirroring_and_inital_patch_size()\n\n        # training pipeline\n        tr_transforms = self.get_training_transforms(\n            patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug,\n            use_mask_for_norm=self.configuration_manager.use_mask_for_norm,\n            is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.foreground_labels,\n            regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None,\n            ignore_label=self.label_manager.ignore_label)\n\n        # validation pipeline\n        val_transforms = self.get_validation_transforms(deep_supervision_scales,\n                                                        is_cascaded=self.is_cascaded,\n                                                        foreground_labels=self.label_manager.foreground_labels,\n                                                        regions=self.label_manager.foreground_regions if\n                                                        self.label_manager.has_regions else None,\n                                                        ignore_label=self.label_manager.ignore_label)\n\n        dataset_tr, dataset_val = self.get_tr_and_val_datasets()\n\n        # we set transforms=None because this trainer still uses batchgenerators, which expects transforms to be passed to\n        # the augmenter rather than to the data loader\n        dl_tr = 
nnUNetDataLoader(dataset_tr, self.batch_size,\n                                 initial_patch_size,\n                                 self.configuration_manager.patch_size,\n                                 self.label_manager,\n                                 oversample_foreground_percent=self.oversample_foreground_percent,\n                                 sampling_probabilities=None, pad_sides=None, transforms=None,\n                                 probabilistic_oversampling=self.probabilistic_oversampling)\n        dl_val = nnUNetDataLoader(dataset_val, self.batch_size,\n                                  self.configuration_manager.patch_size,\n                                  self.configuration_manager.patch_size,\n                                  self.label_manager,\n                                  oversample_foreground_percent=self.oversample_foreground_percent,\n                                  sampling_probabilities=None, pad_sides=None, transforms=None,\n                                  probabilistic_oversampling=self.probabilistic_oversampling)\n\n        allowed_num_processes = get_allowed_n_proc_DA()\n        if allowed_num_processes == 0:\n            mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms)\n            mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms)\n        else:\n            mt_gen_train = NonDetMultiThreadedAugmenter(data_loader=dl_tr, transform=tr_transforms,\n                                                        num_processes=allowed_num_processes, num_cached=6, seeds=None,\n                                                        pin_memory=self.device.type == 'cuda', wait_time=0.02)\n            mt_gen_val = NonDetMultiThreadedAugmenter(data_loader=dl_val,\n                                                      transform=val_transforms,\n                                                      num_processes=max(1, allowed_num_processes // 2),\n                                                      
num_cached=3, seeds=None, pin_memory=self.device.type == 'cuda',\n                                                      wait_time=0.02)\n        # # let's get this party started\n        _ = next(mt_gen_train)\n        _ = next(mt_gen_val)\n        return mt_gen_train, mt_gen_val\n\n    def validation_step(self, batch: dict) -> dict:\n        data = batch['data']\n        target = batch['target']\n\n        data = data.to(self.device, non_blocking=True)\n        if isinstance(target, list):\n            target = [i.to(self.device, non_blocking=True) for i in target]\n        else:\n            target = target.to(self.device, non_blocking=True)\n\n        # Autocast can be annoying\n        # If the device_type is 'cpu' then it's slow as heck and needs to be disabled.\n        # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False)\n        # So autocast will only be active if we have a cuda device.\n        with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():\n            output = self.network(data)\n            del data\n            l = self.loss(output, target)\n\n        # we only need the output with the highest output resolution (if DS enabled)\n        if self.enable_deep_supervision:\n            output = output[0]\n            target = target[0]\n\n        # the following is needed for online evaluation. 
Fake dice (green line)\n        axes = [0] + list(range(2, output.ndim))\n\n        if self.label_manager.has_regions:\n            predicted_segmentation_onehot = (torch.sigmoid(output) > 0.5).long()\n        else:\n            # no need for softmax\n            output_seg = output.argmax(1)[:, None]\n            predicted_segmentation_onehot = torch.zeros(output.shape, device=output.device, dtype=torch.float16)\n            predicted_segmentation_onehot.scatter_(1, output_seg, 1)\n            del output_seg\n\n        if self.label_manager.has_ignore_label:\n            if not self.label_manager.has_regions:\n                mask = target != self.label_manager.ignore_label\n                # CAREFUL that you don't rely on target after this line!\n                target[target == self.label_manager.ignore_label] = 0\n            else:\n                if target.dtype == torch.bool:\n                    mask = ~target[:, -1:]\n                else:\n                    mask = (1 - target[:, -1:]).bool()\n                # CAREFUL that you don't rely on target after this line!\n                target = target[:, :-1].bool()\n        else:\n            mask = None\n\n        tp, fp, fn, _ = get_tp_fp_fn_tn(predicted_segmentation_onehot, target, axes=axes, mask=mask)\n\n        tp_hard = tp.detach().cpu().numpy()\n        fp_hard = fp.detach().cpu().numpy()\n        fn_hard = fn.detach().cpu().numpy()\n        if not self.label_manager.has_regions:\n            # if we train with regions all segmentation heads predict some kind of foreground. In conventional\n            # (softmax training) there needs to be one output for the background. 
We are not interested in the\n            # background Dice\n            # [1:] in order to remove background\n            tp_hard = tp_hard[1:]\n            fp_hard = fp_hard[1:]\n            fn_hard = fn_hard[1:]\n\n        return {'loss': l.detach().cpu().numpy(), 'tp_hard': tp_hard, 'fp_hard': fp_hard, 'fn_hard': fn_hard}\n\n\nclass nnUNetTrainerDA5ord0(nnUNetTrainerDA5):\n    @staticmethod\n    def get_training_transforms(\n            patch_size: Union[np.ndarray, Tuple[int]],\n            rotation_for_DA: RandomScalar,\n            deep_supervision_scales: Union[List, Tuple, None],\n            mirror_axes: Tuple[int, ...],\n            do_dummy_2d_data_aug: bool,\n            use_mask_for_norm: List[bool] = None,\n            is_cascaded: bool = False,\n            foreground_labels: Union[Tuple[int, ...], List[int]] = None,\n            regions: List[Union[List[int], Tuple[int, ...], int]] = None,\n            ignore_label: int = None,\n    ) -> AbstractTransform:\n        matching_axes = np.array([sum([i == j for j in patch_size]) for i in patch_size])\n        valid_axes = list(np.where(matching_axes == np.max(matching_axes))[0])\n\n        tr_transforms = []\n        tr_transforms.append(RenameTransform('target', 'seg', True))\n\n        if do_dummy_2d_data_aug:\n            ignore_axes = (0,)\n            tr_transforms.append(Convert3DTo2DTransform())\n            patch_size_spatial = patch_size[1:]\n        else:\n            patch_size_spatial = patch_size\n            ignore_axes = None\n\n        tr_transforms.append(\n            SpatialTransform(\n                patch_size_spatial,\n                patch_center_dist_from_border=None,\n                do_elastic_deform=False,\n                do_rotation=True,\n                angle_x=rotation_for_DA,\n                angle_y=rotation_for_DA,\n                angle_z=rotation_for_DA,\n                p_rot_per_axis=0.5,\n                do_scale=True,\n                scale=(0.7, 1.43),\n         
       border_mode_data=\"constant\",\n                border_cval_data=0,\n                order_data=0,\n                border_mode_seg=\"constant\",\n                border_cval_seg=-1,\n                order_seg=0,\n                random_crop=False,\n                p_el_per_sample=0.2,\n                p_scale_per_sample=0.2,\n                p_rot_per_sample=0.4,\n                independent_scale_for_each_axis=True,\n            )\n        )\n\n        if do_dummy_2d_data_aug:\n            tr_transforms.append(Convert2DTo3DTransform())\n\n        if np.any(matching_axes > 1):\n            tr_transforms.append(\n                Rot90Transform(\n                    (0, 1, 2, 3), axes=valid_axes, data_key='data', label_key='seg', p_per_sample=0.5\n                ),\n            )\n\n        if np.any(matching_axes > 1):\n            tr_transforms.append(\n                TransposeAxesTransform(valid_axes, data_key='data', label_key='seg', p_per_sample=0.5)\n            )\n\n        tr_transforms.append(OneOfTransform([\n            MedianFilterTransform(\n                (2, 8),\n                same_for_each_channel=False,\n                p_per_sample=0.2,\n                p_per_channel=0.5\n            ),\n            GaussianBlurTransform((0.3, 1.5),\n                                  different_sigma_per_channel=True,\n                                  p_per_sample=0.2,\n                                  p_per_channel=0.5)\n        ]))\n\n        tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))\n\n        tr_transforms.append(BrightnessTransform(0,\n                                                 0.5,\n                                                 per_channel=True,\n                                                 p_per_sample=0.1,\n                                                 p_per_channel=0.5\n                                                 )\n                             )\n\n        tr_transforms.append(OneOfTransform(\n        
    [\n                ContrastAugmentationTransform(\n                    contrast_range=(0.5, 2),\n                    preserve_range=True,\n                    per_channel=True,\n                    data_key='data',\n                    p_per_sample=0.2,\n                    p_per_channel=0.5\n                ),\n                ContrastAugmentationTransform(\n                    contrast_range=(0.5, 2),\n                    preserve_range=False,\n                    per_channel=True,\n                    data_key='data',\n                    p_per_sample=0.2,\n                    p_per_channel=0.5\n                ),\n            ]\n        ))\n\n        tr_transforms.append(\n            SimulateLowResolutionTransform(zoom_range=(0.25, 1),\n                                           per_channel=True,\n                                           p_per_channel=0.5,\n                                           order_downsample=0,\n                                           order_upsample=3,\n                                           p_per_sample=0.15,\n                                           ignore_axes=ignore_axes\n                                           )\n        )\n\n        tr_transforms.append(\n            GammaTransform((0.7, 1.5), invert_image=True, per_channel=True, retain_stats=True, p_per_sample=0.1))\n        tr_transforms.append(\n            GammaTransform((0.7, 1.5), invert_image=True, per_channel=True, retain_stats=True, p_per_sample=0.1))\n\n        if mirror_axes is not None and len(mirror_axes) > 0:\n            tr_transforms.append(MirrorTransform(mirror_axes))\n\n        tr_transforms.append(\n            BlankRectangleTransform([[max(1, p // 10), p // 3] for p in patch_size],\n                                    rectangle_value=np.mean,\n                                    num_rectangles=(1, 5),\n                                    force_square=False,\n                                    p_per_sample=0.4,\n                              
      p_per_channel=0.5\n                                    )\n        )\n\n        tr_transforms.append(\n            BrightnessGradientAdditiveTransform(\n                _brightnessadditive_localgamma_transform_scale,\n                (-0.5, 1.5),\n                max_strength=_brightness_gradient_additive_max_strength,\n                mean_centered=False,\n                same_for_all_channels=False,\n                p_per_sample=0.3,\n                p_per_channel=0.5\n            )\n        )\n\n        tr_transforms.append(\n            LocalGammaTransform(\n                _brightnessadditive_localgamma_transform_scale,\n                (-0.5, 1.5),\n                _local_gamma_gamma,\n                same_for_all_channels=False,\n                p_per_sample=0.3,\n                p_per_channel=0.5\n            )\n        )\n\n        tr_transforms.append(\n            SharpeningTransform(\n                strength=(0.1, 1),\n                same_for_each_channel=False,\n                p_per_sample=0.2,\n                p_per_channel=0.5\n            )\n        )\n\n        if use_mask_for_norm is not None and any(use_mask_for_norm):\n            tr_transforms.append(MaskTransform([i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]],\n                                               mask_idx_in_seg=0, set_outside_to=0))\n\n        tr_transforms.append(RemoveLabelTransform(-1, 0))\n\n        if is_cascaded:\n            if ignore_label is not None:\n                raise NotImplementedError('ignore label not yet supported in cascade')\n            assert foreground_labels is not None, 'We need all_labels for cascade augmentations'\n            use_labels = [i for i in foreground_labels if i != 0]\n            tr_transforms.append(MoveSegAsOneHotToData(1, use_labels, 'seg', 'data'))\n            tr_transforms.append(ApplyRandomBinaryOperatorTransform(\n                channel_idx=list(range(-len(use_labels), 0)),\n                
p_per_sample=0.4,\n                key=\"data\",\n                strel_size=(1, 8),\n                p_per_label=1))\n            tr_transforms.append(\n                RemoveRandomConnectedComponentFromOneHotEncodingTransform(\n                    channel_idx=list(range(-len(use_labels), 0)),\n                    key=\"data\",\n                    p_per_sample=0.2,\n                    fill_with_other_class_p=0,\n                    dont_do_if_covers_more_than_x_percent=0.15))\n\n        tr_transforms.append(RenameTransform('seg', 'target', True))\n\n        if regions is not None:\n            # the ignore label must also be converted\n            tr_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label]\n                                                                       if ignore_label is not None else regions,\n                                                                       'target', 'target'))\n\n        if deep_supervision_scales is not None:\n            tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target',\n                                                              output_key='target'))\n        tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))\n        tr_transforms = Compose(tr_transforms)\n        return tr_transforms\n\n\ndef _brightnessadditive_localgamma_transform_scale(x, y):\n    return np.exp(np.random.uniform(np.log(x[y] // 6), np.log(x[y])))\n\n\ndef _brightness_gradient_additive_max_strength(_x, _y):\n    return np.random.uniform(-5, -1) if np.random.uniform() < 0.5 else np.random.uniform(1, 5)\n\n\ndef _local_gamma_gamma():\n    return np.random.uniform(0.01, 0.8) if np.random.uniform() < 0.5 else np.random.uniform(1.5, 4)\n\n\nclass nnUNetTrainerDA5Segord0(nnUNetTrainerDA5):\n    @staticmethod\n    def get_training_transforms(\n            patch_size: Union[np.ndarray, Tuple[int]],\n            rotation_for_DA: RandomScalar,\n            
deep_supervision_scales: Union[List, Tuple, None],\n            mirror_axes: Tuple[int, ...],\n            do_dummy_2d_data_aug: bool,\n            use_mask_for_norm: List[bool] = None,\n            is_cascaded: bool = False,\n            foreground_labels: Union[Tuple[int, ...], List[int]] = None,\n            regions: List[Union[List[int], Tuple[int, ...], int]] = None,\n            ignore_label: int = None,\n    ) -> AbstractTransform:\n        matching_axes = np.array([sum([i == j for j in patch_size]) for i in patch_size])\n        valid_axes = list(np.where(matching_axes == np.max(matching_axes))[0])\n\n        tr_transforms = []\n        tr_transforms.append(RenameTransform('target', 'seg', True))\n\n        if do_dummy_2d_data_aug:\n            ignore_axes = (0,)\n            tr_transforms.append(Convert3DTo2DTransform())\n            patch_size_spatial = patch_size[1:]\n        else:\n            patch_size_spatial = patch_size\n            ignore_axes = None\n\n        tr_transforms.append(\n            SpatialTransform(\n                patch_size_spatial,\n                patch_center_dist_from_border=None,\n                do_elastic_deform=False,\n                do_rotation=True,\n                angle_x=rotation_for_DA,\n                angle_y=rotation_for_DA,\n                angle_z=rotation_for_DA,\n                p_rot_per_axis=0.5,\n                do_scale=True,\n                scale=(0.7, 1.43),\n                border_mode_data=\"constant\",\n                border_cval_data=0,\n                order_data=3,\n                border_mode_seg=\"constant\",\n                border_cval_seg=-1,\n                order_seg=0,\n                random_crop=False,\n                p_el_per_sample=0.2,\n                p_scale_per_sample=0.2,\n                p_rot_per_sample=0.4,\n                independent_scale_for_each_axis=True,\n            )\n        )\n\n        if do_dummy_2d_data_aug:\n            
tr_transforms.append(Convert2DTo3DTransform())\n\n        if np.any(matching_axes > 1):\n            tr_transforms.append(\n                Rot90Transform(\n                    (0, 1, 2, 3), axes=valid_axes, data_key='data', label_key='seg', p_per_sample=0.5\n                ),\n            )\n\n        if np.any(matching_axes > 1):\n            tr_transforms.append(\n                TransposeAxesTransform(valid_axes, data_key='data', label_key='seg', p_per_sample=0.5)\n            )\n\n        tr_transforms.append(OneOfTransform([\n            MedianFilterTransform(\n                (2, 8),\n                same_for_each_channel=False,\n                p_per_sample=0.2,\n                p_per_channel=0.5\n            ),\n            GaussianBlurTransform((0.3, 1.5),\n                                  different_sigma_per_channel=True,\n                                  p_per_sample=0.2,\n                                  p_per_channel=0.5)\n        ]))\n\n        tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))\n\n        tr_transforms.append(BrightnessTransform(0,\n                                                 0.5,\n                                                 per_channel=True,\n                                                 p_per_sample=0.1,\n                                                 p_per_channel=0.5\n                                                 )\n                             )\n\n        tr_transforms.append(OneOfTransform(\n            [\n                ContrastAugmentationTransform(\n                    contrast_range=(0.5, 2),\n                    preserve_range=True,\n                    per_channel=True,\n                    data_key='data',\n                    p_per_sample=0.2,\n                    p_per_channel=0.5\n                ),\n                ContrastAugmentationTransform(\n                    contrast_range=(0.5, 2),\n                    preserve_range=False,\n                    per_channel=True,\n          
          data_key='data',\n                    p_per_sample=0.2,\n                    p_per_channel=0.5\n                ),\n            ]\n        ))\n\n        tr_transforms.append(\n            SimulateLowResolutionTransform(zoom_range=(0.25, 1),\n                                           per_channel=True,\n                                           p_per_channel=0.5,\n                                           order_downsample=0,\n                                           order_upsample=3,\n                                           p_per_sample=0.15,\n                                           ignore_axes=ignore_axes\n                                           )\n        )\n\n        tr_transforms.append(\n            GammaTransform((0.7, 1.5), invert_image=True, per_channel=True, retain_stats=True, p_per_sample=0.1))\n        tr_transforms.append(\n            GammaTransform((0.7, 1.5), invert_image=True, per_channel=True, retain_stats=True, p_per_sample=0.1))\n\n        if mirror_axes is not None and len(mirror_axes) > 0:\n            tr_transforms.append(MirrorTransform(mirror_axes))\n\n        tr_transforms.append(\n            BlankRectangleTransform([[max(1, p // 10), p // 3] for p in patch_size],\n                                    rectangle_value=np.mean,\n                                    num_rectangles=(1, 5),\n                                    force_square=False,\n                                    p_per_sample=0.4,\n                                    p_per_channel=0.5\n                                    )\n        )\n\n        tr_transforms.append(\n            BrightnessGradientAdditiveTransform(\n                _brightnessadditive_localgamma_transform_scale,\n                (-0.5, 1.5),\n                max_strength=_brightness_gradient_additive_max_strength,\n                mean_centered=False,\n                same_for_all_channels=False,\n                p_per_sample=0.3,\n                p_per_channel=0.5\n            )\n        
)\n\n        tr_transforms.append(\n            LocalGammaTransform(\n                _brightnessadditive_localgamma_transform_scale,\n                (-0.5, 1.5),\n                _local_gamma_gamma,\n                same_for_all_channels=False,\n                p_per_sample=0.3,\n                p_per_channel=0.5\n            )\n        )\n\n        tr_transforms.append(\n            SharpeningTransform(\n                strength=(0.1, 1),\n                same_for_each_channel=False,\n                p_per_sample=0.2,\n                p_per_channel=0.5\n            )\n        )\n\n        if use_mask_for_norm is not None and any(use_mask_for_norm):\n            tr_transforms.append(MaskTransform([i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]],\n                                               mask_idx_in_seg=0, set_outside_to=0))\n\n        tr_transforms.append(RemoveLabelTransform(-1, 0))\n\n        if is_cascaded:\n            if ignore_label is not None:\n                raise NotImplementedError('ignore label not yet supported in cascade')\n            assert foreground_labels is not None, 'We need all_labels for cascade augmentations'\n            use_labels = [i for i in foreground_labels if i != 0]\n            tr_transforms.append(MoveSegAsOneHotToData(1, use_labels, 'seg', 'data'))\n            tr_transforms.append(ApplyRandomBinaryOperatorTransform(\n                channel_idx=list(range(-len(use_labels), 0)),\n                p_per_sample=0.4,\n                key=\"data\",\n                strel_size=(1, 8),\n                p_per_label=1))\n            tr_transforms.append(\n                RemoveRandomConnectedComponentFromOneHotEncodingTransform(\n                    channel_idx=list(range(-len(use_labels), 0)),\n                    key=\"data\",\n                    p_per_sample=0.2,\n                    fill_with_other_class_p=0,\n                    dont_do_if_covers_more_than_x_percent=0.15))\n\n        
tr_transforms.append(RenameTransform('seg', 'target', True))\n\n        if regions is not None:\n            # the ignore label must also be converted\n            tr_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label]\n                                                                       if ignore_label is not None else regions,\n                                                                       'target', 'target'))\n\n        if deep_supervision_scales is not None:\n            tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target',\n                                                              output_key='target'))\n        tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))\n        tr_transforms = Compose(tr_transforms)\n        return tr_transforms\n\n\nclass nnUNetTrainerDA5_10epochs(nnUNetTrainerDA5):\n    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,\n                 device: torch.device = torch.device('cuda')):\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.num_epochs = 10\n"
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDAOrd0.py",
    "content": "from typing import Union, Tuple, List\n\nfrom batchgeneratorsv2.helpers.scalar_type import RandomScalar\nfrom batchgeneratorsv2.transforms.base.basic_transform import BasicTransform\nfrom batchgeneratorsv2.transforms.intensity.brightness import MultiplicativeBrightnessTransform\nfrom batchgeneratorsv2.transforms.intensity.contrast import ContrastTransform, BGContrast\nfrom batchgeneratorsv2.transforms.intensity.gamma import GammaTransform\nfrom batchgeneratorsv2.transforms.intensity.gaussian_noise import GaussianNoiseTransform\nfrom batchgeneratorsv2.transforms.nnunet.random_binary_operator import ApplyRandomBinaryOperatorTransform\nfrom batchgeneratorsv2.transforms.nnunet.remove_connected_components import \\\n    RemoveRandomConnectedComponentFromOneHotEncodingTransform\nfrom batchgeneratorsv2.transforms.nnunet.seg_to_onehot import MoveSegAsOneHotToDataTransform\nfrom batchgeneratorsv2.transforms.noise.gaussian_blur import GaussianBlurTransform\nfrom batchgeneratorsv2.transforms.spatial.low_resolution import SimulateLowResolutionTransform\nfrom batchgeneratorsv2.transforms.spatial.mirroring import MirrorTransform\nfrom batchgeneratorsv2.transforms.spatial.spatial import SpatialTransform\nfrom batchgeneratorsv2.transforms.utils.compose import ComposeTransforms\nfrom batchgeneratorsv2.transforms.utils.deep_supervision_downsampling import DownsampleSegForDSTransform\nfrom batchgeneratorsv2.transforms.utils.nnunet_masking import MaskImageTransform\nfrom batchgeneratorsv2.transforms.utils.pseudo2d import Convert3DTo2DTransform, Convert2DTo3DTransform\nfrom batchgeneratorsv2.transforms.utils.random import RandomTransform\nfrom batchgeneratorsv2.transforms.utils.remove_label import RemoveLabelTansform\nfrom batchgeneratorsv2.transforms.utils.seg_to_regions import ConvertSegmentationToRegionsTransform\nfrom batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter\nfrom 
batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter\n\nfrom nnunetv2.training.dataloading.data_loader import nnUNetDataLoader\nfrom nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer\nfrom nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA\nimport numpy as np\n\n\nclass nnUNetTrainer_DASegOrd0(nnUNetTrainer):\n    @staticmethod\n    def get_training_transforms(\n            patch_size: Union[np.ndarray, Tuple[int]],\n            rotation_for_DA: RandomScalar,\n            deep_supervision_scales: Union[List, Tuple, None],\n            mirror_axes: Tuple[int, ...],\n            do_dummy_2d_data_aug: bool,\n            use_mask_for_norm: List[bool] = None,\n            is_cascaded: bool = False,\n            foreground_labels: Union[Tuple[int, ...], List[int]] = None,\n            regions: List[Union[List[int], Tuple[int, ...], int]] = None,\n            ignore_label: int = None,\n    ) -> BasicTransform:\n        transforms = []\n        if do_dummy_2d_data_aug:\n            ignore_axes = (0,)\n            transforms.append(Convert3DTo2DTransform())\n            patch_size_spatial = patch_size[1:]\n        else:\n            patch_size_spatial = patch_size\n            ignore_axes = None\n        transforms.append(\n            SpatialTransform(\n                patch_size_spatial, patch_center_dist_from_border=0, random_crop=False, p_elastic_deform=0,\n                p_rotation=0.2,\n                rotation=rotation_for_DA, p_scaling=0.2, scaling=(0.7, 1.4), p_synchronize_scaling_across_axes=1,\n                bg_style_seg_sampling=False, mode_seg='nearest'\n            )\n        )\n\n        if do_dummy_2d_data_aug:\n            transforms.append(Convert2DTo3DTransform())\n\n        transforms.append(RandomTransform(\n            GaussianNoiseTransform(\n                noise_variance=(0, 0.1),\n                p_per_channel=1,\n                synchronize_channels=True\n            ), 
apply_probability=0.1\n        ))\n        transforms.append(RandomTransform(\n            GaussianBlurTransform(\n                blur_sigma=(0.5, 1.),\n                synchronize_channels=False,\n                synchronize_axes=False,\n                p_per_channel=0.5, benchmark=True\n            ), apply_probability=0.2\n        ))\n        transforms.append(RandomTransform(\n            MultiplicativeBrightnessTransform(\n                multiplier_range=BGContrast((0.75, 1.25)),\n                synchronize_channels=False,\n                p_per_channel=1\n            ), apply_probability=0.15\n        ))\n        transforms.append(RandomTransform(\n            ContrastTransform(\n                contrast_range=BGContrast((0.75, 1.25)),\n                preserve_range=True,\n                synchronize_channels=False,\n                p_per_channel=1\n            ), apply_probability=0.15\n        ))\n        transforms.append(RandomTransform(\n            SimulateLowResolutionTransform(\n                scale=(0.5, 1),\n                synchronize_channels=False,\n                synchronize_axes=True,\n                ignore_axes=ignore_axes,\n                allowed_channels=None,\n                p_per_channel=0.5\n            ), apply_probability=0.25\n        ))\n        transforms.append(RandomTransform(\n            GammaTransform(\n                gamma=BGContrast((0.7, 1.5)),\n                p_invert_image=1,\n                synchronize_channels=False,\n                p_per_channel=1,\n                p_retain_stats=1\n            ), apply_probability=0.1\n        ))\n        transforms.append(RandomTransform(\n            GammaTransform(\n                gamma=BGContrast((0.7, 1.5)),\n                p_invert_image=0,\n                synchronize_channels=False,\n                p_per_channel=1,\n                p_retain_stats=1\n            ), apply_probability=0.3\n        ))\n        if mirror_axes is not None and len(mirror_axes) > 0:\n    
        transforms.append(\n                MirrorTransform(\n                    allowed_axes=mirror_axes\n                )\n            )\n\n        if use_mask_for_norm is not None and any(use_mask_for_norm):\n            transforms.append(MaskImageTransform(\n                apply_to_channels=[i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]],\n                channel_idx_in_seg=0,\n                set_outside_to=0,\n            ))\n\n        transforms.append(\n            RemoveLabelTansform(-1, 0)\n        )\n        if is_cascaded:\n            assert foreground_labels is not None, 'We need foreground_labels for cascade augmentations'\n            transforms.append(\n                MoveSegAsOneHotToDataTransform(\n                    source_channel_idx=1,\n                    all_labels=foreground_labels,\n                    remove_channel_from_source=True\n                )\n            )\n            transforms.append(\n                RandomTransform(\n                    ApplyRandomBinaryOperatorTransform(\n                        channel_idx=list(range(-len(foreground_labels), 0)),\n                        strel_size=(1, 8),\n                        p_per_label=1\n                    ), apply_probability=0.4\n                )\n            )\n            transforms.append(\n                RandomTransform(\n                    RemoveRandomConnectedComponentFromOneHotEncodingTransform(\n                        channel_idx=list(range(-len(foreground_labels), 0)),\n                        fill_with_other_class_p=0,\n                        dont_do_if_covers_more_than_x_percent=0.15,\n                        p_per_label=1\n                    ), apply_probability=0.2\n                )\n            )\n\n        if regions is not None:\n            # the ignore label must also be converted\n            transforms.append(\n                ConvertSegmentationToRegionsTransform(\n                    regions=list(regions) + [ignore_label] if 
ignore_label is not None else regions,\n                    channel_in_seg=0\n                )\n            )\n\n        if deep_supervision_scales is not None:\n            transforms.append(DownsampleSegForDSTransform(ds_scales=deep_supervision_scales))\n\n        return ComposeTransforms(transforms)\n\n\nclass nnUNetTrainer_DASegOrd0_NoMirroring(nnUNetTrainer_DASegOrd0):\n    def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):\n        rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \\\n            super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()\n        mirror_axes = None\n        self.inference_allowed_mirroring_axes = None\n        return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes\n"
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoDA.py",
    "content": "from typing import Union, Tuple, List\n\nimport numpy as np\nfrom batchgeneratorsv2.helpers.scalar_type import RandomScalar\nfrom batchgeneratorsv2.transforms.base.basic_transform import BasicTransform\n\nfrom nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer\n\n\nclass nnUNetTrainerNoDA(nnUNetTrainer):\n    @staticmethod\n    def get_training_transforms(\n            patch_size: Union[np.ndarray, Tuple[int]],\n            rotation_for_DA: RandomScalar,\n            deep_supervision_scales: Union[List, Tuple, None],\n            mirror_axes: Tuple[int, ...],\n            do_dummy_2d_data_aug: bool,\n            use_mask_for_norm: List[bool] = None,\n            is_cascaded: bool = False,\n            foreground_labels: Union[Tuple[int, ...], List[int]] = None,\n            regions: List[Union[List[int], Tuple[int, ...], int]] = None,\n            ignore_label: int = None,\n    ) -> BasicTransform:\n        return nnUNetTrainer.get_validation_transforms(deep_supervision_scales, is_cascaded, foreground_labels,\n                                                       regions, ignore_label)\n\n    def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):\n        # we need to disable mirroring here so that no mirroring will be applied in inference!\n        rotation_for_DA, do_dummy_2d_data_aug, _, _ = \\\n            super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()\n        mirror_axes = None\n        self.inference_allowed_mirroring_axes = None\n        initial_patch_size = self.configuration_manager.patch_size\n        return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes\n\n"
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoMirroring.py",
    "content": "from typing import Union, Tuple, List\n\nimport numpy as np\nimport torch\nfrom batchgeneratorsv2.helpers.scalar_type import RandomScalar\nfrom batchgeneratorsv2.transforms.base.basic_transform import BasicTransform\nfrom batchgeneratorsv2.transforms.intensity.brightness import MultiplicativeBrightnessTransform\nfrom batchgeneratorsv2.transforms.intensity.contrast import ContrastTransform, BGContrast\nfrom batchgeneratorsv2.transforms.intensity.gamma import GammaTransform\nfrom batchgeneratorsv2.transforms.intensity.gaussian_noise import GaussianNoiseTransform\nfrom batchgeneratorsv2.transforms.nnunet.random_binary_operator import ApplyRandomBinaryOperatorTransform\nfrom batchgeneratorsv2.transforms.nnunet.remove_connected_components import \\\n    RemoveRandomConnectedComponentFromOneHotEncodingTransform\nfrom batchgeneratorsv2.transforms.nnunet.seg_to_onehot import MoveSegAsOneHotToDataTransform\nfrom batchgeneratorsv2.transforms.noise.gaussian_blur import GaussianBlurTransform\nfrom batchgeneratorsv2.transforms.spatial.low_resolution import SimulateLowResolutionTransform\nfrom batchgeneratorsv2.transforms.spatial.mirroring import MirrorTransform\nfrom batchgeneratorsv2.transforms.spatial.spatial import SpatialTransform\nfrom batchgeneratorsv2.transforms.utils.compose import ComposeTransforms\nfrom batchgeneratorsv2.transforms.utils.deep_supervision_downsampling import DownsampleSegForDSTransform\nfrom batchgeneratorsv2.transforms.utils.nnunet_masking import MaskImageTransform\nfrom batchgeneratorsv2.transforms.utils.pseudo2d import Convert3DTo2DTransform, Convert2DTo3DTransform\nfrom batchgeneratorsv2.transforms.utils.random import RandomTransform\nfrom batchgeneratorsv2.transforms.utils.remove_label import RemoveLabelTansform\nfrom batchgeneratorsv2.transforms.utils.seg_to_regions import ConvertSegmentationToRegionsTransform\nfrom nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer\n\n\nclass 
nnUNetTrainerNoMirroring(nnUNetTrainer):\n    def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):\n        rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \\\n            super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()\n        mirror_axes = None\n        self.inference_allowed_mirroring_axes = None\n        return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes\n\n\nclass nnUNetTrainer_onlyMirror01(nnUNetTrainer):\n    \"\"\"\n    Only mirrors along spatial axes 0 and 1 for 3D and 0 for 2D\n    \"\"\"\n    def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):\n        rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \\\n            super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()\n        patch_size = self.configuration_manager.patch_size\n        dim = len(patch_size)\n        if dim == 2:\n            mirror_axes = (0, )\n        else:\n            mirror_axes = (0, 1)\n        self.inference_allowed_mirroring_axes = mirror_axes\n        return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes\n\n\nclass nnUNetTrainer_onlyMirror01_1500ep(nnUNetTrainer_onlyMirror01):\n    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,\n                 device: torch.device = torch.device('cuda')):\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.num_epochs = 1500\n\n\nclass nnUNetTrainer_onlyMirror01_DASegOrd0(nnUNetTrainer_onlyMirror01):\n    @staticmethod\n    def get_training_transforms(\n            patch_size: Union[np.ndarray, Tuple[int]],\n            rotation_for_DA: RandomScalar,\n            deep_supervision_scales: Union[List, Tuple, None],\n            mirror_axes: Tuple[int, ...],\n            do_dummy_2d_data_aug: bool,\n            use_mask_for_norm: List[bool] = None,\n            is_cascaded: bool = False,\n            
foreground_labels: Union[Tuple[int, ...], List[int]] = None,\n            regions: List[Union[List[int], Tuple[int, ...], int]] = None,\n            ignore_label: int = None,\n    ) -> BasicTransform:\n        transforms = []\n        if do_dummy_2d_data_aug:\n            ignore_axes = (0,)\n            transforms.append(Convert3DTo2DTransform())\n            patch_size_spatial = patch_size[1:]\n        else:\n            patch_size_spatial = patch_size\n            ignore_axes = None\n        transforms.append(\n            SpatialTransform(\n                patch_size_spatial, patch_center_dist_from_border=0, random_crop=False, p_elastic_deform=0,\n                p_rotation=0.2,\n                rotation=rotation_for_DA, p_scaling=0.2, scaling=(0.7, 1.4), p_synchronize_scaling_across_axes=1,\n                bg_style_seg_sampling=False, mode_seg='nearest'\n            )\n        )\n\n        if do_dummy_2d_data_aug:\n            transforms.append(Convert2DTo3DTransform())\n\n        transforms.append(RandomTransform(\n            GaussianNoiseTransform(\n                noise_variance=(0, 0.1),\n                p_per_channel=1,\n                synchronize_channels=True\n            ), apply_probability=0.1\n        ))\n        transforms.append(RandomTransform(\n            GaussianBlurTransform(\n                blur_sigma=(0.5, 1.),\n                synchronize_channels=False,\n                synchronize_axes=False,\n                p_per_channel=0.5, benchmark=True\n            ), apply_probability=0.2\n        ))\n        transforms.append(RandomTransform(\n            MultiplicativeBrightnessTransform(\n                multiplier_range=BGContrast((0.75, 1.25)),\n                synchronize_channels=False,\n                p_per_channel=1\n            ), apply_probability=0.15\n        ))\n        transforms.append(RandomTransform(\n            ContrastTransform(\n                contrast_range=BGContrast((0.75, 1.25)),\n                
preserve_range=True,\n                synchronize_channels=False,\n                p_per_channel=1\n            ), apply_probability=0.15\n        ))\n        transforms.append(RandomTransform(\n            SimulateLowResolutionTransform(\n                scale=(0.5, 1),\n                synchronize_channels=False,\n                synchronize_axes=True,\n                ignore_axes=ignore_axes,\n                allowed_channels=None,\n                p_per_channel=0.5\n            ), apply_probability=0.25\n        ))\n        transforms.append(RandomTransform(\n            GammaTransform(\n                gamma=BGContrast((0.7, 1.5)),\n                p_invert_image=1,\n                synchronize_channels=False,\n                p_per_channel=1,\n                p_retain_stats=1\n            ), apply_probability=0.1\n        ))\n        transforms.append(RandomTransform(\n            GammaTransform(\n                gamma=BGContrast((0.7, 1.5)),\n                p_invert_image=0,\n                synchronize_channels=False,\n                p_per_channel=1,\n                p_retain_stats=1\n            ), apply_probability=0.3\n        ))\n        if mirror_axes is not None and len(mirror_axes) > 0:\n            transforms.append(\n                MirrorTransform(\n                    allowed_axes=mirror_axes\n                )\n            )\n\n        if use_mask_for_norm is not None and any(use_mask_for_norm):\n            transforms.append(MaskImageTransform(\n                apply_to_channels=[i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]],\n                channel_idx_in_seg=0,\n                set_outside_to=0,\n            ))\n\n        transforms.append(\n            RemoveLabelTansform(-1, 0)\n        )\n        if is_cascaded:\n            assert foreground_labels is not None, 'We need foreground_labels for cascade augmentations'\n            transforms.append(\n                MoveSegAsOneHotToDataTransform(\n                    
source_channel_idx=1,\n                    all_labels=foreground_labels,\n                    remove_channel_from_source=True\n                )\n            )\n            transforms.append(\n                RandomTransform(\n                    ApplyRandomBinaryOperatorTransform(\n                        channel_idx=list(range(-len(foreground_labels), 0)),\n                        strel_size=(1, 8),\n                        p_per_label=1\n                    ), apply_probability=0.4\n                )\n            )\n            transforms.append(\n                RandomTransform(\n                    RemoveRandomConnectedComponentFromOneHotEncodingTransform(\n                        channel_idx=list(range(-len(foreground_labels), 0)),\n                        fill_with_other_class_p=0,\n                        dont_do_if_covers_more_than_x_percent=0.15,\n                        p_per_label=1\n                    ), apply_probability=0.2\n                )\n            )\n\n        if regions is not None:\n            # the ignore label must also be converted\n            transforms.append(\n                ConvertSegmentationToRegionsTransform(\n                    regions=list(regions) + [ignore_label] if ignore_label is not None else regions,\n                    channel_in_seg=0\n                )\n            )\n\n        if deep_supervision_scales is not None:\n            transforms.append(DownsampleSegForDSTransform(ds_scales=deep_supervision_scales))\n\n        return ComposeTransforms(transforms)"
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainer_noDummy2DDA.py",
    "content": "from nnunetv2.training.data_augmentation.compute_initial_patch_size import get_patch_size\nfrom nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer\nimport numpy as np\n\n\nclass nnUNetTrainer_noDummy2DDA(nnUNetTrainer):\n    def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):\n        do_dummy_2d_data_aug = False\n\n        patch_size = self.configuration_manager.patch_size\n        dim = len(patch_size)\n        if dim == 2:\n            if max(patch_size) / min(patch_size) > 1.5:\n                rotation_for_DA = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi)\n            else:\n                rotation_for_DA = (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi)\n            mirror_axes = (0, 1)\n        elif dim == 3:\n            rotation_for_DA = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)\n            mirror_axes = (0, 1, 2)\n        else:\n            raise RuntimeError()\n\n        # todo this function is stupid. It doesn't even use the correct scale range (we keep things as they were in the\n        #  old nnunet for now)\n        initial_patch_size = get_patch_size(patch_size[-dim:],\n                                            rotation_for_DA,\n                                            rotation_for_DA,\n                                            rotation_for_DA,\n                                            (0.85, 1.25))\n\n        self.print_to_log_file(f'do_dummy_2d_data_aug: {do_dummy_2d_data_aug}')\n        self.inference_allowed_mirroring_axes = mirror_axes\n\n        return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes"
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/variants/loss/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerCELoss.py",
    "content": "import torch\nfrom nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper\nfrom nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer\nfrom nnunetv2.training.loss.robust_ce_loss import RobustCrossEntropyLoss\nimport numpy as np\n\n\nclass nnUNetTrainerCELoss(nnUNetTrainer):\n    def _build_loss(self):\n        assert not self.label_manager.has_regions, \"regions not supported by this trainer\"\n        loss = RobustCrossEntropyLoss(\n            weight=None, ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100\n        )\n\n        # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases\n        # this gives higher resolution outputs more weight in the loss\n        if self.enable_deep_supervision:\n            deep_supervision_scales = self._get_deep_supervision_scales()\n            weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))])\n            weights[-1] = 0\n\n            # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1\n            weights = weights / weights.sum()\n            # now wrap the loss\n            loss = DeepSupervisionWrapper(loss, weights)\n        return loss\n\n\nclass nnUNetTrainerCELoss_5epochs(nnUNetTrainerCELoss):\n    def __init__(\n        self,\n        plans: dict,\n        configuration: str,\n        fold: int,\n        dataset_json: dict,\n        device: torch.device = torch.device(\"cuda\"),\n    ):\n        \"\"\"used for debugging plans etc\"\"\"\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.num_epochs = 5\n"
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerDiceLoss.py",
    "content": "import numpy as np\nimport torch\n\nfrom nnunetv2.training.loss.compound_losses import DC_and_BCE_loss, DC_and_CE_loss\nfrom nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper\nfrom nnunetv2.training.loss.dice import MemoryEfficientSoftDiceLoss\nfrom nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer\nfrom nnunetv2.utilities.helpers import softmax_helper_dim1\n\n\nclass nnUNetTrainerDiceLoss(nnUNetTrainer):\n    def _build_loss(self):\n        loss = MemoryEfficientSoftDiceLoss(**{'batch_dice': self.configuration_manager.batch_dice,\n                                    'do_bg': self.label_manager.has_regions, 'smooth': 1e-5, 'ddp': self.is_ddp},\n                            apply_nonlin=torch.sigmoid if self.label_manager.has_regions else softmax_helper_dim1)\n\n        if self.enable_deep_supervision:\n            deep_supervision_scales = self._get_deep_supervision_scales()\n\n            # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases\n            # this gives higher resolution outputs more weight in the loss\n            weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))])\n            weights[-1] = 0\n\n            # we don't use the lowest 2 outputs. 
Normalize weights so that they sum to 1\n            weights = weights / weights.sum()\n            # now wrap the loss\n            loss = DeepSupervisionWrapper(loss, weights)\n        return loss\n\n\nclass nnUNetTrainerDiceCELoss_noSmooth(nnUNetTrainer):\n    def _build_loss(self):\n        # set smooth to 0\n        if self.label_manager.has_regions:\n            loss = DC_and_BCE_loss({},\n                                   {'batch_dice': self.configuration_manager.batch_dice,\n                                    'do_bg': True, 'smooth': 0, 'ddp': self.is_ddp},\n                                   use_ignore_label=self.label_manager.ignore_label is not None,\n                                   dice_class=MemoryEfficientSoftDiceLoss)\n        else:\n            loss = DC_and_CE_loss({'batch_dice': self.configuration_manager.batch_dice,\n                                   'smooth': 0, 'do_bg': False, 'ddp': self.is_ddp}, {}, weight_ce=1, weight_dice=1,\n                                  ignore_label=self.label_manager.ignore_label,\n                                  dice_class=MemoryEfficientSoftDiceLoss)\n\n        if self.enable_deep_supervision:\n            deep_supervision_scales = self._get_deep_supervision_scales()\n\n            # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases\n            # this gives higher resolution outputs more weight in the loss\n            weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))])\n            weights[-1] = 0\n\n            # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1\n            weights = weights / weights.sum()\n            # now wrap the loss\n            loss = DeepSupervisionWrapper(loss, weights)\n        return loss\n\n"
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerTopkLoss.py",
    "content": "from nnunetv2.training.loss.compound_losses import DC_and_topk_loss\nfrom nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper\nfrom nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer\nimport numpy as np\nfrom nnunetv2.training.loss.robust_ce_loss import TopKLoss\n\n\nclass nnUNetTrainerTopk10Loss(nnUNetTrainer):\n    def _build_loss(self):\n        assert not self.label_manager.has_regions, \"regions not supported by this trainer\"\n        loss = TopKLoss(\n            ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100, k=10\n        )\n\n        if self.enable_deep_supervision:\n            deep_supervision_scales = self._get_deep_supervision_scales()\n\n            # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases\n            # this gives higher resolution outputs more weight in the loss\n            weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))])\n            weights[-1] = 0\n\n            # we don't use the lowest 2 outputs. 
Normalize weights so that they sum to 1\n            weights = weights / weights.sum()\n            # now wrap the loss\n            loss = DeepSupervisionWrapper(loss, weights)\n        return loss\n\n\nclass nnUNetTrainerTopk10LossLS01(nnUNetTrainer):\n    def _build_loss(self):\n        assert not self.label_manager.has_regions, \"regions not supported by this trainer\"\n        loss = TopKLoss(\n            ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100,\n            k=10,\n            label_smoothing=0.1,\n        )\n\n        if self.enable_deep_supervision:\n            deep_supervision_scales = self._get_deep_supervision_scales()\n\n            # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases\n            # this gives higher resolution outputs more weight in the loss\n            weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))])\n            weights[-1] = 0\n\n            # we don't use the lowest 2 outputs. 
Normalize weights so that they sum to 1\n            weights = weights / weights.sum()\n            # now wrap the loss\n            loss = DeepSupervisionWrapper(loss, weights)\n        return loss\n\n\nclass nnUNetTrainerDiceTopK10Loss(nnUNetTrainer):\n    def _build_loss(self):\n        assert not self.label_manager.has_regions, \"regions not supported by this trainer\"\n        loss = DC_and_topk_loss(\n            {\"batch_dice\": self.configuration_manager.batch_dice, \"smooth\": 1e-5, \"do_bg\": False, \"ddp\": self.is_ddp},\n            {\"k\": 10, \"label_smoothing\": 0.0},\n            weight_ce=1,\n            weight_dice=1,\n            ignore_label=self.label_manager.ignore_label,\n        )\n        if self.enable_deep_supervision:\n            deep_supervision_scales = self._get_deep_supervision_scales()\n\n            # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases\n            # this gives higher resolution outputs more weight in the loss\n            weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))])\n            weights[-1] = 0\n\n            # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1\n            weights = weights / weights.sum()\n            # now wrap the loss\n            loss = DeepSupervisionWrapper(loss, weights)\n        return loss\n"
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/variants/lr_schedule/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/variants/lr_schedule/nnUNetTrainerCosAnneal.py",
    "content": "import torch\nfrom torch.optim.lr_scheduler import CosineAnnealingLR\n\nfrom nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer\n\n\nclass nnUNetTrainerCosAnneal(nnUNetTrainer):\n    def configure_optimizers(self):\n        optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,\n                                    momentum=0.99, nesterov=True)\n        lr_scheduler = CosineAnnealingLR(optimizer, T_max=self.num_epochs)\n        return optimizer, lr_scheduler\n\n"
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/variants/lr_schedule/nnUNetTrainer_warmup.py",
    "content": "from typing import Union\n\nimport torch\nfrom torch._dynamo import OptimizedModule\n\nfrom nnunetv2.training.lr_scheduler.warmup import Lin_incr_LRScheduler, PolyLRScheduler_offset\nfrom nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer\nfrom torch.nn.parallel import DistributedDataParallel as DDP\n\nfrom nnunetv2.utilities.helpers import empty_cache\n\n\nclass nnUNetTrainer_warmup(nnUNetTrainer):\n    \"\"\"\n    Does a warmup of the entire architecture\n    Then does normal training\n    \"\"\"\n\n    def __init__(\n        self,\n        plans: dict,\n        configuration: str,\n        fold: int,\n        dataset_json: dict,\n        device: torch.device = torch.device(\"cuda\"),\n    ):\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        #### hyperparameters for warmup\n        self.warmup_duration_whole_net = 50  # lin increase whole network\n        self.num_epochs = 1000\n        self.training_stage = None  # 'warmup_all', 'train'\n\n    def configure_optimizers(self, stage: str = \"warmup_all\"):\n        assert stage in [\"warmup_all\", \"train\"]\n\n        if self.training_stage == stage:\n            return self.optimizer, self.lr_scheduler\n\n        if isinstance(self.network, DDP):\n            params = self.network.module.parameters()\n        else:\n            params = self.network.parameters()\n\n        if stage == \"warmup_all\":\n            self.print_to_log_file(\"train whole net, warmup\")\n            optimizer = torch.optim.SGD(\n                params, self.initial_lr, weight_decay=self.weight_decay, momentum=0.99, nesterov=True\n            )\n            lr_scheduler = Lin_incr_LRScheduler(optimizer, self.initial_lr, self.warmup_duration_whole_net)\n            self.print_to_log_file(f\"Initialized warmup_all optimizer and lr_scheduler at epoch {self.current_epoch}\")\n        else:\n            self.print_to_log_file(\"train whole net, default schedule\")\n           
 if self.training_stage == \"warmup_all\":\n                # we can keep the existing optimizer and don't need to create a new one. This will allow us to keep\n                # the accumulated momentum terms which already point in a useful direction\n                optimizer = self.optimizer\n            else:\n                optimizer = torch.optim.SGD(\n                    params, self.initial_lr, weight_decay=self.weight_decay, momentum=0.99, nesterov=True\n                )\n            lr_scheduler = PolyLRScheduler_offset(\n                optimizer, self.initial_lr, self.num_epochs, self.warmup_duration_whole_net\n            )\n            self.print_to_log_file(f\"Initialized train optimizer and lr_scheduler at epoch {self.current_epoch}\")\n        self.training_stage = stage\n        empty_cache(self.device)\n        return optimizer, lr_scheduler\n\n    def on_train_epoch_start(self):\n        if self.current_epoch == 0:\n            self.optimizer, self.lr_scheduler = self.configure_optimizers(\"warmup_all\")\n        elif self.current_epoch == self.warmup_duration_whole_net:\n            self.optimizer, self.lr_scheduler = self.configure_optimizers(\"train\")\n\n        super().on_train_epoch_start()\n\n    def load_checkpoint(self, filename_or_checkpoint: Union[dict, str]) -> None:\n        \"\"\"\n        We need to overwrite that entire function because we need to fiddle the correct optimizer in between\n        loading the checkpoint and applying the optimizer states. Yuck.\n        \"\"\"\n        if not self.was_initialized:\n            self.initialize()\n\n        if isinstance(filename_or_checkpoint, str):\n            checkpoint = torch.load(filename_or_checkpoint, map_location=self.device)\n        else:\n            # an already-loaded checkpoint dict was passed in; use it directly\n            checkpoint = filename_or_checkpoint\n        # if state dict comes from nn.DataParallel but we use non-parallel model here then the state dict keys do not\n        # match. 
Use heuristic to make it match\n        new_state_dict = {}\n        for k, value in checkpoint[\"network_weights\"].items():\n            key = k\n            if key not in self.network.state_dict().keys() and key.startswith(\"module.\"):\n                key = key[7:]\n            new_state_dict[key] = value\n\n        self.my_init_kwargs = checkpoint[\"init_args\"]\n        self.current_epoch = checkpoint[\"current_epoch\"]\n        self.logger.load_checkpoint(checkpoint[\"logging\"])\n        self._best_ema = checkpoint[\"_best_ema\"]\n        self.inference_allowed_mirroring_axes = (\n            checkpoint[\"inference_allowed_mirroring_axes\"]\n            if \"inference_allowed_mirroring_axes\" in checkpoint.keys()\n            else self.inference_allowed_mirroring_axes\n        )\n\n        # messing with state dict naming schemes. Facepalm.\n        if self.is_ddp:\n            if isinstance(self.network.module, OptimizedModule):\n                self.network.module._orig_mod.load_state_dict(new_state_dict)\n            else:\n                self.network.module.load_state_dict(new_state_dict)\n        else:\n            if isinstance(self.network, OptimizedModule):\n                self.network._orig_mod.load_state_dict(new_state_dict)\n            else:\n                self.network.load_state_dict(new_state_dict)\n\n        # it's fine to do this every time we load because configure_optimizers will be a no-op if the correct optimizer\n        # and lr scheduler are already set up\n        if self.current_epoch < self.warmup_duration_whole_net:\n            self.optimizer, self.lr_scheduler = self.configure_optimizers(\"warmup_all\")\n        else:\n            self.optimizer, self.lr_scheduler = self.configure_optimizers(\"train\")\n\n        self.optimizer.load_state_dict(checkpoint[\"optimizer_state\"])\n        if self.grad_scaler is not None:\n            if checkpoint[\"grad_scaler_state\"] is not None:\n                
self.grad_scaler.load_state_dict(checkpoint[\"grad_scaler_state\"])"
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/variants/network_architecture/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerBN.py",
    "content": "from typing import Union, Tuple, List\nfrom dynamic_network_architectures.building_blocks.helper import get_matching_batchnorm\nfrom torch import nn\n\nfrom nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer\n\n\nclass nnUNetTrainerBN(nnUNetTrainer):\n    @staticmethod\n    def build_network_architecture(architecture_class_name: str,\n                                   arch_init_kwargs: dict,\n                                   arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],\n                                   num_input_channels: int,\n                                   num_output_channels: int,\n                                   enable_deep_supervision: bool = True) -> nn.Module:\n\n        if 'norm_op' not in arch_init_kwargs.keys():\n            raise RuntimeError(\"'norm_op' not found in arch_init_kwargs. This does not look like an architecture \"\n                               \"I can hack BN into. This trainer only works with default nnU-Net architectures.\")\n\n        from pydoc import locate\n        conv_op = locate(arch_init_kwargs['conv_op'])\n        bn_class = get_matching_batchnorm(conv_op)\n        arch_init_kwargs['norm_op'] = bn_class.__module__ + '.' + bn_class.__name__\n        arch_init_kwargs['norm_op_kwargs'] = {'eps': 1e-5, 'affine': True}\n\n        return nnUNetTrainer.build_network_architecture(architecture_class_name,\n                                                        arch_init_kwargs,\n                                                        arch_init_kwargs_req_import,\n                                                        num_input_channels,\n                                                        num_output_channels, enable_deep_supervision)\n\n"
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerNoDeepSupervision.py",
    "content": "from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer\nimport torch\n\n\nclass nnUNetTrainerNoDeepSupervision(nnUNetTrainer):\n    def __init__(\n        self,\n        plans: dict,\n        configuration: str,\n        fold: int,\n        dataset_json: dict,\n        device: torch.device = torch.device(\"cuda\"),\n    ):\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.enable_deep_supervision = False\n"
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/variants/optimizer/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/variants/optimizer/nnUNetTrainerAdam.py",
    "content": "import torch\nfrom torch.optim import Adam, AdamW\n\nfrom nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler\nfrom nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer\n\n\nclass nnUNetTrainerAdam(nnUNetTrainer):\n    def configure_optimizers(self):\n        optimizer = AdamW(self.network.parameters(),\n                          lr=self.initial_lr,\n                          weight_decay=self.weight_decay,\n                          amsgrad=True)\n        # optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,\n        #                             momentum=0.99, nesterov=True)\n        lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs)\n        return optimizer, lr_scheduler\n\n\nclass nnUNetTrainerVanillaAdam(nnUNetTrainer):\n    def configure_optimizers(self):\n        optimizer = Adam(self.network.parameters(),\n                         lr=self.initial_lr,\n                         weight_decay=self.weight_decay)\n        # optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,\n        #                             momentum=0.99, nesterov=True)\n        lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs)\n        return optimizer, lr_scheduler\n\n\nclass nnUNetTrainerVanillaAdam1en3(nnUNetTrainerVanillaAdam):\n    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,\n                 device: torch.device = torch.device('cuda')):\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.initial_lr = 1e-3\n\n\nclass nnUNetTrainerVanillaAdam3en4(nnUNetTrainerVanillaAdam):\n    # https://twitter.com/karpathy/status/801621764144971776?lang=en\n    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,\n                 device: torch.device = torch.device('cuda')):\n        
super().__init__(plans, configuration, fold, dataset_json, device)\n        self.initial_lr = 3e-4\n\n\nclass nnUNetTrainerAdam1en3(nnUNetTrainerAdam):\n    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,\n                 device: torch.device = torch.device('cuda')):\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.initial_lr = 1e-3\n\n\nclass nnUNetTrainerAdam3en4(nnUNetTrainerAdam):\n    # https://twitter.com/karpathy/status/801621764144971776?lang=en\n    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,\n                 device: torch.device = torch.device('cuda')):\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.initial_lr = 3e-4\n"
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/variants/optimizer/nnUNetTrainerAdan.py",
    "content": "import torch\n\nfrom nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler\nfrom nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer\nfrom torch.optim.lr_scheduler import CosineAnnealingLR\ntry:\n    from adan_pytorch import Adan\nexcept ImportError:\n    Adan = None\n\n\nclass nnUNetTrainerAdan(nnUNetTrainer):\n    def configure_optimizers(self):\n        if Adan is None:\n            raise RuntimeError('This trainer requires adan_pytorch to be installed, install with \"pip install adan-pytorch\"')\n        optimizer = Adan(self.network.parameters(),\n                         lr=self.initial_lr,\n                         # betas=(0.02, 0.08, 0.01), defaults\n                         weight_decay=self.weight_decay)\n        # optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,\n        #                             momentum=0.99, nesterov=True)\n        lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs)\n        return optimizer, lr_scheduler\n\n\nclass nnUNetTrainerAdan1en3(nnUNetTrainerAdan):\n    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,\n                 device: torch.device = torch.device('cuda')):\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.initial_lr = 1e-3\n\n\nclass nnUNetTrainerAdan3en4(nnUNetTrainerAdan):\n    # https://twitter.com/karpathy/status/801621764144971776?lang=en\n    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,\n                 device: torch.device = torch.device('cuda')):\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.initial_lr = 3e-4\n\n\nclass nnUNetTrainerAdan1en1(nnUNetTrainerAdan):\n    # this trainer makes no sense -> nan!\n    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,\n                 device: torch.device = 
torch.device('cuda')):\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.initial_lr = 1e-1\n\n\nclass nnUNetTrainerAdanCosAnneal(nnUNetTrainerAdan):\n    # def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,\n    #              device: torch.device = torch.device('cuda')):\n    #     super().__init__(plans, configuration, fold, dataset_json, device)\n    #     self.num_epochs = 15\n\n    def configure_optimizers(self):\n        if Adan is None:\n            raise RuntimeError('This trainer requires adan_pytorch to be installed, install with \"pip install adan-pytorch\"')\n        optimizer = Adan(self.network.parameters(),\n                         lr=self.initial_lr,\n                         # betas=(0.02, 0.08, 0.01), defaults\n                         weight_decay=self.weight_decay)\n        # optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,\n        #                             momentum=0.99, nesterov=True)\n        lr_scheduler = CosineAnnealingLR(optimizer, T_max=self.num_epochs)\n        return optimizer, lr_scheduler\n\n"
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/variants/sampling/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/variants/sampling/nnUNetTrainer_probabilisticOversampling.py",
    "content": "import numpy as np\nimport torch\nfrom torch import distributed as dist\n\nfrom nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer\n\n\nclass nnUNetTrainer_probabilisticOversampling(nnUNetTrainer):\n    \"\"\"\n    sampling of foreground happens randomly and not for the last 33% of samples in a batch\n    since most trainings happen with batch size 2 and nnunet guarantees at least one fg sample, effectively this can\n    be 50%\n    Here we compute the actual oversampling percentage used by nnUNetTrainer in order to be as consistent as possible.\n    If we switch to this oversampling then we can keep it at a constant 0.33 or whatever.\n    \"\"\"\n    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,\n                 device: torch.device = torch.device('cuda')):\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.probabilistic_oversampling = True\n        self.oversample_foreground_percent = float(np.mean(\n            [not sample_idx < round(self.configuration_manager.batch_size * (1 - self.oversample_foreground_percent))\n             for sample_idx in range(self.configuration_manager.batch_size)]))\n        self.print_to_log_file(f\"self.oversample_foreground_percent {self.oversample_foreground_percent}\")\n\n    def _set_batch_size_and_oversample(self):\n        if not self.is_ddp:\n            # set batch size to what the plan says, leave oversample untouched\n            self.batch_size = self.configuration_manager.batch_size\n        else:\n            # batch size is distributed over DDP workers and we need to change oversample_percent for each worker\n\n            world_size = dist.get_world_size()\n            my_rank = dist.get_rank()\n\n            global_batch_size = self.configuration_manager.batch_size\n            assert global_batch_size >= world_size, 'Cannot run DDP if the batch size is smaller than the number of ' \\\n                         
                           'GPUs... Duh.'\n\n            batch_size_per_GPU = [global_batch_size // world_size] * world_size\n            batch_size_per_GPU = [batch_size_per_GPU[i] + 1\n                                  if (batch_size_per_GPU[i] * world_size + i) < global_batch_size\n                                  else batch_size_per_GPU[i]\n                                  for i in range(len(batch_size_per_GPU))]\n            assert sum(batch_size_per_GPU) == global_batch_size\n            print(\"worker\", my_rank, \"batch_size\", batch_size_per_GPU[my_rank])\n            print(\"worker\", my_rank, \"oversample\", self.oversample_foreground_percent)\n\n            self.batch_size = batch_size_per_GPU[my_rank]\n\n\nclass nnUNetTrainer_probabilisticOversampling_033(nnUNetTrainer_probabilisticOversampling):\n    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,\n                 device: torch.device = torch.device('cuda')):\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.oversample_foreground_percent = 0.33\n    \n    \nclass nnUNetTrainer_probabilisticOversampling_010(nnUNetTrainer_probabilisticOversampling):\n    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,\n                 device: torch.device = torch.device('cuda')):\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.oversample_foreground_percent = 0.1\n"
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/variants/training_length/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs.py",
    "content": "import torch\n\nfrom nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer\n\n\nclass nnUNetTrainer_5epochs(nnUNetTrainer):\n    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,\n                 device: torch.device = torch.device('cuda')):\n        \"\"\"used for debugging plans etc\"\"\"\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.num_epochs = 5\n\n\nclass nnUNetTrainer_1epoch(nnUNetTrainer):\n    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,\n                 device: torch.device = torch.device('cuda')):\n        \"\"\"used for debugging plans etc\"\"\"\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.num_epochs = 1\n\n\nclass nnUNetTrainer_10epochs(nnUNetTrainer):\n    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,\n                 device: torch.device = torch.device('cuda')):\n        \"\"\"used for debugging plans etc\"\"\"\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.num_epochs = 10\n\n\nclass nnUNetTrainer_20epochs(nnUNetTrainer):\n    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,\n                 device: torch.device = torch.device('cuda')):\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.num_epochs = 20\n\n\nclass nnUNetTrainer_50epochs(nnUNetTrainer):\n    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,\n                 device: torch.device = torch.device('cuda')):\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.num_epochs = 50\n\n\nclass nnUNetTrainer_100epochs(nnUNetTrainer):\n    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,\n                 device: torch.device = torch.device('cuda')):\n       
 super().__init__(plans, configuration, fold, dataset_json, device)\n        self.num_epochs = 100\n\n\nclass nnUNetTrainer_250epochs(nnUNetTrainer):\n    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,\n                 device: torch.device = torch.device('cuda')):\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.num_epochs = 250\n\n\nclass nnUNetTrainer_500epochs(nnUNetTrainer):\n    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,\n                 device: torch.device = torch.device('cuda')):\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.num_epochs = 500\n\n\nclass nnUNetTrainer_750epochs(nnUNetTrainer):\n    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,\n                 device: torch.device = torch.device('cuda')):\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.num_epochs = 750\n\n\nclass nnUNetTrainer_2000epochs(nnUNetTrainer):\n    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,\n                 device: torch.device = torch.device('cuda')):\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.num_epochs = 2000\n\n    \nclass nnUNetTrainer_4000epochs(nnUNetTrainer):\n    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,\n                 device: torch.device = torch.device('cuda')):\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.num_epochs = 4000\n\n\nclass nnUNetTrainer_8000epochs(nnUNetTrainer):\n    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,\n                 device: torch.device = torch.device('cuda')):\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.num_epochs = 8000\n"
  },
  {
    "path": "nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs_NoMirroring.py",
    "content": "import torch\n\nfrom nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer\n\n\nclass nnUNetTrainer_250epochs_NoMirroring(nnUNetTrainer):\n    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, \n                 device: torch.device = torch.device('cuda')):\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.num_epochs = 250\n\n    def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):\n        rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \\\n            super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()\n        mirror_axes = None\n        self.inference_allowed_mirroring_axes = None\n        return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes\n\n\nclass nnUNetTrainer_2000epochs_NoMirroring(nnUNetTrainer):\n    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, \n                 device: torch.device = torch.device('cuda')):\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.num_epochs = 2000\n\n    def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):\n        rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \\\n            super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()\n        mirror_axes = None\n        self.inference_allowed_mirroring_axes = None\n        return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes\n\n    \nclass nnUNetTrainer_4000epochs_NoMirroring(nnUNetTrainer):\n    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, \n                 device: torch.device = torch.device('cuda')):\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.num_epochs = 4000\n\n    def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):\n        
rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \\\n            super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()\n        mirror_axes = None\n        self.inference_allowed_mirroring_axes = None\n        return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes\n\n\nclass nnUNetTrainer_8000epochs_NoMirroring(nnUNetTrainer):\n    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, \n                 device: torch.device = torch.device('cuda')):\n        super().__init__(plans, configuration, fold, dataset_json, device)\n        self.num_epochs = 8000\n\n    def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):\n        rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \\\n            super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()\n        mirror_axes = None\n        self.inference_allowed_mirroring_axes = None\n        return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes\n\n"
  },
  {
    "path": "nnunetv2/utilities/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/utilities/collate_outputs.py",
    "content": "from typing import List\n\nimport numpy as np\n\n\ndef collate_outputs(outputs: List[dict]):\n    \"\"\"\n    used to collate default train_step and validation_step outputs. If you want something different then you gotta\n    extend this\n\n    we expect outputs to be a list of dictionaries where each of the dict has the same set of keys\n    \"\"\"\n    collated = {}\n    for k in outputs[0].keys():\n        if np.isscalar(outputs[0][k]):\n            collated[k] = [o[k] for o in outputs]\n        elif isinstance(outputs[0][k], np.ndarray):\n            collated[k] = np.vstack([o[k][None] for o in outputs])\n        elif isinstance(outputs[0][k], list):\n            collated[k] = [item for o in outputs for item in o[k]]\n        else:\n            raise ValueError(f'Cannot collate input of type {type(outputs[0][k])}. '\n                             f'Modify collate_outputs to add this functionality')\n    return collated"
  },
  {
    "path": "nnunetv2/utilities/crossval_split.py",
    "content": "from typing import List\n\nimport numpy as np\nfrom sklearn.model_selection import KFold\n\n\ndef generate_crossval_split(train_identifiers: List[str], seed=12345, n_splits=5) -> List[dict[str, List[str]]]:\n    splits = []\n    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=seed)\n    for i, (train_idx, test_idx) in enumerate(kfold.split(train_identifiers)):\n        train_keys = np.array(train_identifiers)[train_idx]\n        test_keys = np.array(train_identifiers)[test_idx]\n        splits.append({})\n        splits[-1]['train'] = list(train_keys)\n        splits[-1]['val'] = list(test_keys)\n    return splits\n"
  },
  {
    "path": "nnunetv2/utilities/dataset_name_id_conversion.py",
    "content": "#    Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not use this file except in compliance with the License.\n#    You may obtain a copy of the License at\n#\n#        http://www.apache.org/licenses/LICENSE-2.0\n#\n#    Unless required by applicable law or agreed to in writing, software\n#    distributed under the License is distributed on an \"AS IS\" BASIS,\n#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#    See the License for the specific language governing permissions and\n#    limitations under the License.\nfrom typing import Union\n\nfrom nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw, nnUNet_results\nfrom batchgenerators.utilities.file_and_folder_operations import *\nimport numpy as np\n\n\ndef find_candidate_datasets(dataset_id: int):\n    startswith = \"Dataset%03.0d\" % dataset_id\n    if nnUNet_preprocessed is not None and isdir(nnUNet_preprocessed):\n        candidates_preprocessed = subdirs(nnUNet_preprocessed, prefix=startswith, join=False)\n    else:\n        candidates_preprocessed = []\n\n    if nnUNet_raw is not None and isdir(nnUNet_raw):\n        candidates_raw = subdirs(nnUNet_raw, prefix=startswith, join=False)\n    else:\n        candidates_raw = []\n\n    candidates_trained_models = []\n    if nnUNet_results is not None and isdir(nnUNet_results):\n        candidates_trained_models += subdirs(nnUNet_results, prefix=startswith, join=False)\n\n    all_candidates = candidates_preprocessed + candidates_raw + candidates_trained_models\n    unique_candidates = np.unique(all_candidates)\n    return unique_candidates\n\n\ndef convert_id_to_dataset_name(dataset_id: int):\n    unique_candidates = find_candidate_datasets(dataset_id)\n    if len(unique_candidates) > 1:\n        raise RuntimeError(\"More than one dataset name found for dataset id 
%d. Please correct that. (I looked in the \"\n                           \"following folders:\\n%s\\n%s\\n%s\" % (dataset_id, nnUNet_raw, nnUNet_preprocessed, nnUNet_results))\n    if len(unique_candidates) == 0:\n        raise RuntimeError(f\"Could not find a dataset with the ID {dataset_id}. Make sure the requested dataset ID \"\n                           f\"exists and that nnU-Net knows where raw and preprocessed data are located \"\n                           f\"(see Documentation - Installation). Here are your currently defined folders:\\n\"\n                           f\"nnUNet_preprocessed={os.environ.get('nnUNet_preprocessed') if os.environ.get('nnUNet_preprocessed') is not None else 'None'}\\n\"\n                           f\"nnUNet_results={os.environ.get('nnUNet_results') if os.environ.get('nnUNet_results') is not None else 'None'}\\n\"\n                           f\"nnUNet_raw={os.environ.get('nnUNet_raw') if os.environ.get('nnUNet_raw') is not None else 'None'}\\n\"\n                           f\"If something is not right, adapt your environment variables.\")\n    return unique_candidates[0]\n\n\ndef convert_dataset_name_to_id(dataset_name: str):\n    assert dataset_name.startswith(\"Dataset\")\n    dataset_id = int(dataset_name[7:10])\n    return dataset_id\n\n\ndef maybe_convert_to_dataset_name(dataset_name_or_id: Union[int, str]) -> str:\n    if isinstance(dataset_name_or_id, str) and dataset_name_or_id.startswith(\"Dataset\"):\n        return dataset_name_or_id\n    if isinstance(dataset_name_or_id, str):\n        try:\n            dataset_name_or_id = int(dataset_name_or_id)\n        except ValueError:\n            raise ValueError(\"dataset_name_or_id was a string and did not start with 'Dataset' so we tried to \"\n                             \"convert it to a dataset ID (int). That failed, however. Please give an integer number \"\n                             \"('1', '2', etc) or a correct dataset name. 
Your input: %s\" % dataset_name_or_id)\n    return convert_id_to_dataset_name(dataset_name_or_id)\n"
  },
  {
    "path": "nnunetv2/utilities/ddp_allgather.py",
    "content": "#    Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not use this file except in compliance with the License.\n#    You may obtain a copy of the License at\n#\n#        http://www.apache.org/licenses/LICENSE-2.0\n#\n#    Unless required by applicable law or agreed to in writing, software\n#    distributed under the License is distributed on an \"AS IS\" BASIS,\n#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#    See the License for the specific language governing permissions and\n#    limitations under the License.\nfrom typing import Any, Optional, Tuple\n\nimport torch\nfrom torch import distributed\n\n\ndef print_if_rank0(*args):\n    if distributed.get_rank() == 0:\n        print(*args)\n\n\nclass AllGatherGrad(torch.autograd.Function):\n    # stolen from pytorch lightning\n    @staticmethod\n    def forward(\n        ctx: Any,\n        tensor: torch.Tensor,\n        group: Optional[\"torch.distributed.ProcessGroup\"] = None,\n    ) -> torch.Tensor:\n        ctx.group = group\n\n        gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]\n\n        torch.distributed.all_gather(gathered_tensor, tensor, group=group)\n        gathered_tensor = torch.stack(gathered_tensor, dim=0)\n\n        return gathered_tensor\n\n    @staticmethod\n    def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:\n        grad_output = torch.cat(grad_output)\n\n        torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group)\n\n        return grad_output[torch.distributed.get_rank()], None\n\n"
  },
  {
    "path": "nnunetv2/utilities/default_n_proc_DA.py",
    "content": "import subprocess\nimport os\n\n\ndef get_allowed_n_proc_DA():\n    \"\"\"\n    This function is used to set the number of processes used on different Systems. It is specific to our cluster\n    infrastructure at DKFZ. You can modify it to suit your needs. Everything is allowed.\n\n    IMPORTANT: if the environment variable nnUNet_n_proc_DA is set it will overwrite anything in this script\n    (see first line).\n\n    Interpret the output as the number of processes used for data augmentation PER GPU.\n\n    The way it is implemented here is simply a look up table. We know the hostnames, CPU and GPU configurations of our\n    systems and set the numbers accordingly. For example, a system with 4 GPUs and 48 threads can use 12 threads per\n    GPU without overloading the CPU (technically 11 because we have a main process as well), so that's what we use.\n    \"\"\"\n\n    if 'nnUNet_n_proc_DA' in os.environ.keys():\n        use_this = int(os.environ['nnUNet_n_proc_DA'])\n    else:\n        hostname = subprocess.getoutput(['hostname'])\n        if hostname in ['Fabian', 'isensee-']:\n            use_this = 12\n        elif hostname in ['hdf19-gpu16', 'hdf19-gpu17', 'hdf19-gpu18', 'hdf19-gpu19', 'e230-AMDworkstation']:\n            use_this = 16\n        elif hostname.startswith('e230-dgx1'):\n            use_this = 10\n        elif hostname.startswith('hdf18-gpu') or hostname.startswith('e132-comp'):\n            use_this = 16\n        elif hostname.startswith('e230-dgx2'):\n            use_this = 6\n        elif hostname.startswith('e230-dgxa100-'):\n            use_this = 28\n        elif hostname.startswith('e230-thinka100-'):\n            use_this = 20\n        elif hostname.startswith('lsf22-gpu'):\n            use_this = 28\n        elif hostname.startswith('hdf19-gpu') or hostname.startswith('e071-gpu'):\n            use_this = 12\n        elif 'superh200' in hostname.lower():\n            use_this = 48\n        elif 'superl40s' in 
hostname.lower():\n            use_this = 28\n        else:\n            use_this = 12  # default value\n\n    # os.cpu_count() may return None if the CPU count cannot be determined\n    use_this = min(use_this, os.cpu_count() or use_this)\n    return use_this\n"
  },
  {
    "path": "nnunetv2/utilities/file_path_utilities.py",
    "content": "from multiprocessing import Pool\nfrom typing import Union, Tuple\nimport numpy as np\nfrom batchgenerators.utilities.file_and_folder_operations import *\n\nfrom nnunetv2.configuration import default_num_processes\nfrom nnunetv2.paths import nnUNet_results\nfrom nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name\n\n\ndef convert_trainer_plans_config_to_identifier(trainer_name, plans_identifier, configuration):\n    return f'{trainer_name}__{plans_identifier}__{configuration}'\n\n\ndef convert_identifier_to_trainer_plans_config(identifier: str):\n    return os.path.basename(identifier).split('__')\n\n\ndef get_output_folder(dataset_name_or_id: Union[str, int], trainer_name: str = 'nnUNetTrainer',\n                      plans_identifier: str = 'nnUNetPlans', configuration: str = '3d_fullres',\n                      fold: Union[str, int] = None) -> str:\n    tmp = join(nnUNet_results, maybe_convert_to_dataset_name(dataset_name_or_id),\n               convert_trainer_plans_config_to_identifier(trainer_name, plans_identifier, configuration))\n    if fold is not None:\n        tmp = join(tmp, f'fold_{fold}')\n    return tmp\n\n\ndef parse_dataset_trainer_plans_configuration_from_path(path: str):\n    folders = split_path(path)\n    # this here can be a little tricky because we are making assumptions. Let's hope this never fails lol\n\n    # safer to make this depend on two conditions, the fold_x and the DatasetXXX\n    # first let's see if some fold_X is present\n    fold_x_present = [i.startswith('fold_') for i in folders]\n    if any(fold_x_present):\n        idx = fold_x_present.index(True)\n        # OK now two entries before that there should be DatasetXXX\n        assert len(folders[:idx]) >= 2, 'Bad path, cannot extract what I need. 
Your path needs to be at least ' \\\n                                        'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work'\n        if folders[idx - 2].startswith('Dataset'):\n            split = folders[idx - 1].split('__')\n            assert len(split) == 3, 'Bad path, cannot extract what I need. Your path needs to be at least ' \\\n                                        'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work'\n            return folders[idx - 2], *split\n    else:\n        # we can only check for dataset followed by a string that is separable into three strings by splitting with '__'\n        # look for DatasetXXX\n        dataset_folder = [i.startswith('Dataset') for i in folders]\n        if any(dataset_folder):\n            idx = dataset_folder.index(True)\n            assert len(folders) > (idx + 1), 'Bad path, cannot extract what I need. Your path needs to be at least ' \\\n                                        'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work'\n            split = folders[idx + 1].split('__')\n            assert len(split) == 3, 'Bad path, cannot extract what I need. 
Your path needs to be at least ' \\\n                                       'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work'\n            return folders[idx], *split\n\n\ndef get_ensemble_name(model1_folder, model2_folder, folds: Tuple[int, ...]):\n    identifier = 'ensemble___' + os.path.basename(model1_folder) + '___' + \\\n                 os.path.basename(model2_folder) + '___' + folds_tuple_to_string(folds)\n    return identifier\n\n\ndef get_ensemble_name_from_d_tr_c(dataset, tr1, p1, c1, tr2, p2, c2, folds: Tuple[int, ...]):\n    model1_folder = get_output_folder(dataset, tr1, p1, c1)\n    model2_folder = get_output_folder(dataset, tr2, p2, c2)\n\n    return get_ensemble_name(model1_folder, model2_folder, folds)\n\n\ndef convert_ensemble_folder_to_model_identifiers_and_folds(ensemble_folder: str):\n    prefix, *models, folds = os.path.basename(ensemble_folder).split('___')\n    return models, folds\n\n\ndef folds_tuple_to_string(folds: Union[List[int], Tuple[int, ...]]):\n    s = str(folds[0])\n    for f in folds[1:]:\n        s += f\"_{f}\"\n    return s\n\n\ndef folds_string_to_tuple(folds_string: str):\n    folds = folds_string.split('_')\n    res = []\n    for f in folds:\n        try:\n            res.append(int(f))\n        except ValueError:\n            res.append(f)\n    return res\n\n\ndef check_workers_alive_and_busy(export_pool: Pool, worker_list: List, results_list: List, allowed_num_queued: int = 0):\n    \"\"\"\n\n    returns True if the number of results that are not ready is greater than or equal to the number of available workers + allowed_num_queued\n    \"\"\"\n    alive = [i.is_alive() for i in worker_list]\n    if not all(alive):\n        raise RuntimeError('Some background workers are no longer alive')\n\n    not_ready = [not i.ready() for i in results_list]\n    if sum(not_ready) >= (len(export_pool._pool) + allowed_num_queued):\n        return True\n    return False\n\n\nif __name__ == '__main__':\n    ### well at this point I could just 
write tests...\n    path = '/home/fabian/results/nnUNet_remake/Dataset002_Heart/nnUNetModule__nnUNetPlans__3d_fullres'\n    print(parse_dataset_trainer_plans_configuration_from_path(path))\n    path = 'Dataset002_Heart/nnUNetModule__nnUNetPlans__3d_fullres'\n    print(parse_dataset_trainer_plans_configuration_from_path(path))\n    path = '/home/fabian/results/nnUNet_remake/Dataset002_Heart/nnUNetModule__nnUNetPlans__3d_fullres/fold_all'\n    print(parse_dataset_trainer_plans_configuration_from_path(path))\n    try:\n        path = '/home/fabian/results/nnUNet_remake/Dataset002_Heart/'\n        print(parse_dataset_trainer_plans_configuration_from_path(path))\n    except AssertionError:\n        print('yayy, assertion works')\n"
  },
  {
    "path": "nnunetv2/utilities/find_class_by_name.py",
    "content": "import importlib\nimport pkgutil\n\nfrom batchgenerators.utilities.file_and_folder_operations import *\n\n\ndef recursive_find_python_class(folder: str, class_name: str, current_module: str):\n    tr = None\n    for importer, modname, ispkg in pkgutil.iter_modules([folder]):\n        # print(modname, ispkg)\n        if not ispkg:\n            m = importlib.import_module(current_module + \".\" + modname)\n            if hasattr(m, class_name):\n                tr = getattr(m, class_name)\n                break\n\n    if tr is None:\n        for importer, modname, ispkg in pkgutil.iter_modules([folder]):\n            if ispkg:\n                next_current_module = current_module + \".\" + modname\n                tr = recursive_find_python_class(join(folder, modname), class_name, current_module=next_current_module)\n            if tr is not None:\n                break\n    return tr\n"
  },
  {
    "path": "nnunetv2/utilities/get_network_from_plans.py",
    "content": "import pydoc\nimport warnings\nfrom typing import Union\n\nfrom nnunetv2.utilities.find_class_by_name import recursive_find_python_class\nfrom batchgenerators.utilities.file_and_folder_operations import join\n\n\ndef get_network_from_plans(arch_class_name, arch_kwargs, arch_kwargs_req_import, input_channels, output_channels,\n                           allow_init=True, deep_supervision: Union[bool, None] = None):\n    network_class = arch_class_name\n    architecture_kwargs = dict(**arch_kwargs)\n    for ri in arch_kwargs_req_import:\n        if architecture_kwargs[ri] is not None:\n            architecture_kwargs[ri] = pydoc.locate(architecture_kwargs[ri])\n\n    nw_class = pydoc.locate(network_class)\n    # sometimes things move around, this makes it so that we can at least recover some of that\n    if nw_class is None:\n        warnings.warn(f'Network class {network_class} not found. Attempting to locate it within '\n                      f'dynamic_network_architectures.architectures...')\n        import dynamic_network_architectures\n        nw_class = recursive_find_python_class(join(dynamic_network_architectures.__path__[0], \"architectures\"),\n                                               network_class.split(\".\")[-1],\n                                               'dynamic_network_architectures.architectures')\n        if nw_class is not None:\n            print(f'FOUND IT: {nw_class}')\n        else:\n            raise ImportError('Network class could not be found, please check/correct your plans file')\n\n    if deep_supervision is not None:\n        architecture_kwargs['deep_supervision'] = deep_supervision\n\n    network = nw_class(\n        input_channels=input_channels,\n        num_classes=output_channels,\n        **architecture_kwargs\n    )\n\n    if hasattr(network, 'initialize') and allow_init:\n        network.apply(network.initialize)\n\n    return network\n\nif __name__ == \"__main__\":\n    import torch\n\n    model = 
get_network_from_plans(\n        arch_class_name=\"dynamic_network_architectures.architectures.unet.ResidualEncoderUNet\",\n        arch_kwargs={\n            \"n_stages\": 7,\n            \"features_per_stage\": [32, 64, 128, 256, 512, 512, 512],\n            \"conv_op\": \"torch.nn.modules.conv.Conv2d\",\n            \"kernel_sizes\": [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3]],\n            \"strides\": [[1, 1], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2]],\n            \"n_blocks_per_stage\": [1, 3, 4, 6, 6, 6, 6],\n            \"n_conv_per_stage_decoder\": [1, 1, 1, 1, 1, 1],\n            \"conv_bias\": True,\n            \"norm_op\": \"torch.nn.modules.instancenorm.InstanceNorm2d\",\n            \"norm_op_kwargs\": {\"eps\": 1e-05, \"affine\": True},\n            \"dropout_op\": None,\n            \"dropout_op_kwargs\": None,\n            \"nonlin\": \"torch.nn.LeakyReLU\",\n            \"nonlin_kwargs\": {\"inplace\": True},\n        },\n        arch_kwargs_req_import=[\"conv_op\", \"norm_op\", \"dropout_op\", \"nonlin\"],\n        input_channels=1,\n        output_channels=4,\n        allow_init=True,\n        deep_supervision=True,\n    )\n    data = torch.rand((8, 1, 256, 256))\n    target = torch.rand(size=(8, 1, 256, 256))\n    outputs = model(data) # this should be a list of torch.Tensor"
  },
  {
    "path": "nnunetv2/utilities/helpers.py",
    "content": "import torch\n\n\ndef softmax_helper_dim0(x: torch.Tensor) -> torch.Tensor:\n    return torch.softmax(x, 0)\n\n\ndef softmax_helper_dim1(x: torch.Tensor) -> torch.Tensor:\n    return torch.softmax(x, 1)\n\n\ndef empty_cache(device: torch.device):\n    if device.type == 'cuda':\n        torch.cuda.empty_cache()\n    elif device.type == 'mps':\n        from torch import mps\n        mps.empty_cache()\n    else:\n        pass\n\n\nclass dummy_context(object):\n    def __enter__(self):\n        pass\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        pass\n"
  },
  {
    "path": "nnunetv2/utilities/json_export.py",
    "content": "from collections.abc import Iterable\n\nimport numpy as np\nimport torch\n\n\ndef recursive_fix_for_json_export(my_dict: dict):\n    # json is ... a very nice thing to have\n    # 'cannot serialize object of type bool_/int64/float64'. Apart from that of course...\n    keys = list(my_dict.keys())  # cannot iterate over keys() if we change keys....\n    for k in keys:\n        if isinstance(k, (np.int64, np.int32, np.int8, np.uint8)):\n            tmp = my_dict[k]\n            del my_dict[k]\n            my_dict[int(k)] = tmp\n            del tmp\n            k = int(k)\n\n        if isinstance(my_dict[k], dict):\n            recursive_fix_for_json_export(my_dict[k])\n        elif isinstance(my_dict[k], np.ndarray):\n            assert my_dict[k].ndim == 1, 'only 1d arrays are supported'\n            my_dict[k] = fix_types_iterable(my_dict[k], output_type=list)\n        elif isinstance(my_dict[k], (np.bool_,)):\n            my_dict[k] = bool(my_dict[k])\n        elif isinstance(my_dict[k], (np.int64, np.int32, np.int8, np.uint8)):\n            my_dict[k] = int(my_dict[k])\n        elif isinstance(my_dict[k], (np.float32, np.float64, np.float16)):\n            my_dict[k] = float(my_dict[k])\n        elif isinstance(my_dict[k], list):\n            my_dict[k] = fix_types_iterable(my_dict[k], output_type=type(my_dict[k]))\n        elif isinstance(my_dict[k], tuple):\n            my_dict[k] = fix_types_iterable(my_dict[k], output_type=tuple)\n        elif isinstance(my_dict[k], torch.device):\n            my_dict[k] = str(my_dict[k])\n        else:\n            pass  # pray it can be serialized\n\n\ndef fix_types_iterable(iterable, output_type):\n    # this sh!t is hacky as hell and will break if you use it for anything outside nnunet. 
Keep your hands off of this.\n    out = []\n    for i in iterable:\n        if type(i) in (np.int64, np.int32, np.int8, np.uint8):\n            out.append(int(i))\n        elif isinstance(i, dict):\n            recursive_fix_for_json_export(i)\n            out.append(i)\n        elif type(i) in (np.float32, np.float64, np.float16):\n            out.append(float(i))\n        elif type(i) in (np.bool_,):\n            out.append(bool(i))\n        elif isinstance(i, str):\n            out.append(i)\n        elif isinstance(i, Iterable):\n            # print('recursive call on', i, type(i))\n            out.append(fix_types_iterable(i, type(i)))\n        else:\n            out.append(i)\n    return output_type(out)\n"
  },
  {
    "path": "nnunetv2/utilities/label_handling/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/utilities/label_handling/label_handling.py",
    "content": "from __future__ import annotations\nfrom time import time\nfrom typing import Union, List, Tuple, Type\n\nimport numpy as np\nimport torch\nfrom acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice, insert_crop_into_image\nfrom batchgenerators.utilities.file_and_folder_operations import join\n\nimport nnunetv2\nfrom nnunetv2.utilities.find_class_by_name import recursive_find_python_class\nfrom nnunetv2.utilities.helpers import softmax_helper_dim0\n\nfrom typing import TYPE_CHECKING\n\n# see https://adamj.eu/tech/2021/05/13/python-type-hints-how-to-fix-circular-imports/\nif TYPE_CHECKING:\n    from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager\n\n\nclass LabelManager(object):\n    def __init__(self, label_dict: dict, regions_class_order: Union[List[int], None], force_use_labels: bool = False,\n                 inference_nonlin=None):\n        self._sanity_check(label_dict)\n        self.label_dict = label_dict\n        self.regions_class_order = regions_class_order\n        self._force_use_labels = force_use_labels\n\n        if force_use_labels:\n            self._has_regions = False\n        else:\n            self._has_regions: bool = any(\n                [isinstance(i, (tuple, list)) and len(i) > 1 for i in self.label_dict.values()])\n\n        self._ignore_label: Union[None, int] = self._determine_ignore_label()\n        self._all_labels: List[int] = self._get_all_labels()\n\n        self._regions: Union[None, List[Union[int, Tuple[int, ...]]]] = self._get_regions()\n\n        if self.has_ignore_label:\n            assert self.ignore_label == max(\n                self.all_labels) + 1, 'If you use the ignore label it must have the highest ' \\\n                                      'label value! It cannot be 0 or in between other labels. 
' \\\n                                      'Sorry bro.'\n\n        if inference_nonlin is None:\n            self.inference_nonlin = torch.sigmoid if self.has_regions else softmax_helper_dim0\n        else:\n            self.inference_nonlin = inference_nonlin\n\n    def _sanity_check(self, label_dict: dict):\n        if 'background' not in label_dict:\n            raise RuntimeError('Background label not declared (remember that this should be label 0!)')\n        bg_label = label_dict['background']\n        if isinstance(bg_label, (tuple, list)):\n            raise RuntimeError(f\"Background label must be 0. Not a list. Not a tuple. Your background label: {bg_label}\")\n        assert int(bg_label) == 0, f\"Background label must be 0. Your background label: {bg_label}\"\n        # not sure if we want to allow regions that contain background. I don't immediately see how this could cause\n        # problems so we allow it for now. That doesn't mean that this is explicitly supported. It could be that this\n        # just crashes.\n\n    def _get_all_labels(self) -> List[int]:\n        all_labels = []\n        for k, r in self.label_dict.items():\n            # ignore label is not going to be used, hence the name. 
Duh.\n            if k == 'ignore':\n                continue\n            if isinstance(r, (tuple, list)):\n                for ri in r:\n                    all_labels.append(int(ri))\n            else:\n                all_labels.append(int(r))\n        all_labels = list(np.unique(all_labels))\n        all_labels.sort()\n        return all_labels\n\n    def _get_regions(self) -> Union[None, List[Union[int, Tuple[int, ...]]]]:\n        if not self._has_regions or self._force_use_labels:\n            return None\n        else:\n            assert self.regions_class_order is not None, 'if region-based training is requested then you need to ' \\\n                                                         'define regions_class_order!'\n            regions = []\n            for k, r in self.label_dict.items():\n                # ignore ignore label\n                if k == 'ignore':\n                    continue\n                # ignore regions that are background\n                if (np.isscalar(r) and r == 0) \\\n                        or \\\n                        (isinstance(r, (tuple, list)) and len(np.unique(r)) == 1 and np.unique(r)[0] == 0):\n                    continue\n                if isinstance(r, list):\n                    r = tuple(r)\n                regions.append(r)\n            assert len(self.regions_class_order) == len(regions), 'regions_class_order must have as ' \\\n                                                                  'many entries as there are ' \\\n                                                                  'regions'\n            return regions\n\n    def _determine_ignore_label(self) -> Union[None, int]:\n        ignore_label = self.label_dict.get('ignore')\n        if ignore_label is not None:\n            assert isinstance(ignore_label, int), f'Ignore label has to be an integer. It cannot be a region ' \\\n                                                  f'(list/tuple). 
Got {type(ignore_label)}.'\n        return ignore_label\n\n    @property\n    def has_regions(self) -> bool:\n        return self._has_regions\n\n    @property\n    def has_ignore_label(self) -> bool:\n        return self.ignore_label is not None\n\n    @property\n    def all_regions(self) -> Union[None, List[Union[int, Tuple[int, ...]]]]:\n        return self._regions\n\n    @property\n    def all_labels(self) -> List[int]:\n        return self._all_labels\n\n    @property\n    def ignore_label(self) -> Union[None, int]:\n        return self._ignore_label\n\n    def apply_inference_nonlin(self, logits: Union[np.ndarray, torch.Tensor]) -> \\\n            Union[np.ndarray, torch.Tensor]:\n        \"\"\"\n        logits has to have shape (c, x, y(, z)) where c is the number of classes/regions\n        \"\"\"\n        if isinstance(logits, np.ndarray):\n            logits = torch.from_numpy(logits)\n\n        with torch.no_grad():\n            # softmax etc is not implemented for half\n            logits = logits.float()\n            probabilities = self.inference_nonlin(logits)\n\n        return probabilities\n\n    @torch.inference_mode()\n    def convert_probabilities_to_segmentation(self, predicted_probabilities: Union[np.ndarray, torch.Tensor]) -> \\\n            Union[np.ndarray, torch.Tensor]:\n        \"\"\"\n        assumes that inference_nonlinearity was already applied!\n\n        predicted_probabilities has to have shape (c, x, y(, z)) where c is the number of classes/regions\n        \"\"\"\n        if not isinstance(predicted_probabilities, (np.ndarray, torch.Tensor)):\n            raise RuntimeError(f\"Unexpected input type. 
Expected np.ndarray or torch.Tensor,\"\n                               f\" got {type(predicted_probabilities)}\")\n\n        if self.has_regions:\n            assert self.regions_class_order is not None, 'if region-based training is requested then you need to ' \\\n                                                         'define regions_class_order!'\n            # check correct number of outputs\n        assert predicted_probabilities.shape[0] == self.num_segmentation_heads, \\\n            f'unexpected number of channels in predicted_probabilities. Expected {self.num_segmentation_heads}, ' \\\n            f'got {predicted_probabilities.shape[0]}. Remember that predicted_probabilities should have shape ' \\\n            f'(c, x, y(, z)).'\n\n        if self.has_regions:\n            if isinstance(predicted_probabilities, np.ndarray):\n                segmentation = np.zeros(predicted_probabilities.shape[1:], dtype=np.uint16)\n            else:\n                # no uint16 in torch\n                segmentation = torch.zeros(predicted_probabilities.shape[1:], dtype=torch.int16,\n                                           device=predicted_probabilities.device)\n            for i, c in enumerate(self.regions_class_order):\n                segmentation[predicted_probabilities[i] > 0.5] = c\n        else:\n            # numpy is faster than torch. 
:facepalm:\n            is_numpy = isinstance(predicted_probabilities, np.ndarray)\n            if not is_numpy:\n                predicted_probabilities = predicted_probabilities.numpy()\n            segmentation = predicted_probabilities.argmax(0)\n            if not is_numpy:\n                segmentation = torch.from_numpy(segmentation)\n\n        return segmentation\n\n    @torch.inference_mode()\n    def convert_logits_to_segmentation(self, predicted_logits: Union[np.ndarray, torch.Tensor]) -> \\\n            Union[np.ndarray, torch.Tensor]:\n        input_is_numpy = isinstance(predicted_logits, np.ndarray)\n        # we can skip this step if we do not have regions. Argmax is the same for logits and probabilities\n        if self.has_regions:\n            probabilities = self.apply_inference_nonlin(predicted_logits)\n        else:\n            probabilities = predicted_logits\n        if input_is_numpy and isinstance(probabilities, torch.Tensor):\n            probabilities = probabilities.cpu().numpy()\n        return self.convert_probabilities_to_segmentation(probabilities)\n\n    def revert_cropping_on_probabilities(self, predicted_probabilities: Union[torch.Tensor, np.ndarray],\n                                         bbox: List[List[int]],\n                                         original_shape: Union[List[int], Tuple[int, ...]]):\n        \"\"\"\n        ONLY USE THIS WITH PROBABILITIES, DO NOT USE LOGITS AND DO NOT USE FOR SEGMENTATION MAPS!!!\n\n        predicted_probabilities must be (c, x, y(, z))\n\n        Why do we do this here? Well if we pad probabilities we need to make sure that convert_logits_to_segmentation\n        correctly returns background in the padded areas. Also we want to be able to look at the padded probabilities\n        and not have strange artifacts.\n        Only LabelManager knows how this needs to be done. 
So let's let him/her do it, ok?\n        \"\"\"\n        # revert cropping\n        probs_reverted_cropping = np.zeros((predicted_probabilities.shape[0], *original_shape),\n                                           dtype=predicted_probabilities.dtype) \\\n            if isinstance(predicted_probabilities, np.ndarray) else \\\n            torch.zeros((predicted_probabilities.shape[0], *original_shape), dtype=predicted_probabilities.dtype)\n\n        if not self.has_regions:\n            probs_reverted_cropping[0] = 1\n\n        probs_reverted_cropping = insert_crop_into_image(probs_reverted_cropping, predicted_probabilities, bbox)\n        return probs_reverted_cropping\n\n    @staticmethod\n    def filter_background(classes_or_regions: Union[List[int], List[Union[int, Tuple[int, ...]]]]):\n        # heck yeah\n        # This is definitely taking list comprehension too far. Enjoy.\n        return [i for i in classes_or_regions if\n                ((not isinstance(i, (tuple, list))) and i != 0)\n                or\n                (isinstance(i, (tuple, list)) and not (\n                        len(np.unique(i)) == 1 and np.unique(i)[0] == 0))]\n\n    @property\n    def foreground_regions(self):\n        return self.filter_background(self.all_regions)\n\n    @property\n    def foreground_labels(self):\n        return self.filter_background(self.all_labels)\n\n    @property\n    def num_segmentation_heads(self):\n        if self.has_regions:\n            return len(self.foreground_regions)\n        else:\n            return len(self.all_labels)\n\n\ndef get_labelmanager_class_from_plans(plans: dict) -> Type[LabelManager]:\n    if 'label_manager' not in plans.keys():\n        print('No label manager specified in plans. 
Using default: LabelManager')\n        return LabelManager\n    else:\n        labelmanager_class = recursive_find_python_class(join(nnunetv2.__path__[0], \"utilities\", \"label_handling\"),\n                                                         plans['label_manager'],\n                                                         current_module=\"nnunetv2.utilities.label_handling\")\n        return labelmanager_class\n\n\ndef convert_labelmap_to_one_hot(segmentation: Union[np.ndarray, torch.Tensor],\n                                all_labels: Union[List, torch.Tensor, np.ndarray, tuple],\n                                output_dtype=None) -> Union[np.ndarray, torch.Tensor]:\n    \"\"\"\n    if output_dtype is None then we use np.uint8/torch.uint8 (uint16 if the largest label is >= 255)\n    if input is torch.Tensor then output will be on the same device\n\n    np.ndarray is faster than torch.Tensor\n\n    if segmentation is torch.Tensor, this function will be faster if it is LongTensor. If it is something else we have\n    to cast which takes time.\n\n    IMPORTANT: This function only works properly if your labels are consecutive integers, so something like 0, 1, 2, 3, ...\n    DO NOT use it with 0, 32, 123, 255, ... 
or whatever (fix your labels, yo)\n    \"\"\"\n    if isinstance(segmentation, torch.Tensor):\n        result = torch.zeros((len(all_labels), *segmentation.shape),\n                             dtype=output_dtype if output_dtype is not None else (torch.uint8 if max(all_labels) < 255 else torch.uint16),\n                             device=segmentation.device)\n        # variant 1, 2x faster than 2\n        result.scatter_(0, segmentation[None].long(), 1)  # why does this have to be long!?\n        # variant 2, slower than 1\n        # for i, l in enumerate(all_labels):\n        #     result[i] = segmentation == l\n    else:\n        result = np.zeros((len(all_labels), *segmentation.shape),\n                          dtype=output_dtype if output_dtype is not None else (np.uint8 if max(all_labels) < 255 else np.uint16))\n        # variant 1, fastest in my testing\n        for i, l in enumerate(all_labels):\n            result[i] = segmentation == l\n        # variant 2. Takes about twice as long so nah\n        # result = np.eye(len(all_labels))[segmentation].transpose((3, 0, 1, 2))\n    return result\n\n\ndef determine_num_input_channels(plans_manager: PlansManager,\n                                 configuration_or_config_manager: Union[str, ConfigurationManager],\n                                 dataset_json: dict) -> int:\n    if isinstance(configuration_or_config_manager, str):\n        config_manager = plans_manager.get_configuration(configuration_or_config_manager)\n    else:\n        config_manager = configuration_or_config_manager\n\n    label_manager = plans_manager.get_label_manager(dataset_json)\n    num_modalities = len(dataset_json['modality']) if 'modality' in dataset_json.keys() else len(dataset_json['channel_names'])\n\n    # cascade has different number of input channels\n    if config_manager.previous_stage_name is not None:\n        num_label_inputs = len(label_manager.foreground_labels)\n        num_input_channels = num_modalities + 
num_label_inputs\n    else:\n        num_input_channels = num_modalities\n    return num_input_channels\n\n\nif __name__ == '__main__':\n    # this code used to be able to differentiate variant 1 and 2 to measure time.\n    num_labels = 7\n    seg = np.random.randint(0, num_labels, size=(256, 256, 256), dtype=np.uint8)\n    seg_torch = torch.from_numpy(seg)\n    st = time()\n    onehot_npy = convert_labelmap_to_one_hot(seg, np.arange(num_labels))\n    time_1 = time()\n    onehot_npy2 = convert_labelmap_to_one_hot(seg, np.arange(num_labels))\n    time_2 = time()\n    onehot_torch = convert_labelmap_to_one_hot(seg_torch, np.arange(num_labels))\n    time_torch = time()\n    onehot_torch2 = convert_labelmap_to_one_hot(seg_torch, np.arange(num_labels))\n    time_torch2 = time()\n    print(\n        f'np: {time_1 - st}, np2: {time_2 - time_1}, torch: {time_torch - time_2}, torch2: {time_torch2 - time_torch}')\n    onehot_torch = onehot_torch.numpy()\n    onehot_torch2 = onehot_torch2.numpy()\n    print(np.all(onehot_torch == onehot_npy))\n    print(np.all(onehot_torch2 == onehot_npy))\n"
  },
  {
    "path": "nnunetv2/utilities/network_initialization.py",
    "content": "from torch import nn\n\n\nclass InitWeights_He(object):\n    def __init__(self, neg_slope=1e-2):\n        self.neg_slope = neg_slope\n\n    def __call__(self, module):\n        if isinstance(module, nn.Conv3d) or isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d) or isinstance(module, nn.ConvTranspose3d):\n            module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope)\n            if module.bias is not None:\n                module.bias = nn.init.constant_(module.bias, 0)\n"
  },
  {
    "path": "nnunetv2/utilities/overlay_plots.py",
    "content": "#    Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not use this file except in compliance with the License.\n#    You may obtain a copy of the License at\n#\n#        http://www.apache.org/licenses/LICENSE-2.0\n#\n#    Unless required by applicable law or agreed to in writing, software\n#    distributed under the License is distributed on an \"AS IS\" BASIS,\n#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#    See the License for the specific language governing permissions and\n#    limitations under the License.\nimport multiprocessing\nfrom multiprocessing.pool import Pool\nfrom typing import Tuple, Union\n\nimport numpy as np\nimport pandas as pd\nfrom batchgenerators.utilities.file_and_folder_operations import *\nfrom nnunetv2.configuration import default_num_processes\nfrom nnunetv2.imageio.base_reader_writer import BaseReaderWriter\nfrom nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json\nfrom nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed\nfrom nnunetv2.training.dataloading.nnunet_dataset import infer_dataset_class, nnUNetBaseDataset\nfrom nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name\nfrom nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager\nfrom nnunetv2.utilities.utils import get_identifiers_from_splitted_dataset_folder, \\\n    get_filenames_of_train_images_and_targets\n\ncolor_cycle = (\n    \"000000\",\n    \"4363d8\",\n    \"f58231\",\n    \"3cb44b\",\n    \"e6194B\",\n    \"911eb4\",\n    \"ffe119\",\n    \"bfef45\",\n    \"42d4f4\",\n    \"f032e6\",\n    \"000075\",\n    \"9A6324\",\n    \"808000\",\n    \"800000\",\n    \"469990\",\n)\n\n\ndef hex_to_rgb(hex: str):\n    assert len(hex) == 6\n    return tuple(int(hex[i:i + 2], 16) for i in (0, 2, 
4))\n\n\ndef generate_overlay(input_image: np.ndarray, segmentation: np.ndarray, mapping: dict = None,\n                     color_cycle: Tuple[str, ...] = color_cycle,\n                     overlay_intensity: float = 0.6):\n    \"\"\"\n    image can be 2d greyscale or 2d RGB (color channel in last dimension!)\n\n    Segmentation must be label map of same shape as image (w/o color channels)\n\n    mapping can be label_id -> idx_in_cycle or None\n\n    returned image is scaled to [0, 255] (uint8)!!!\n    \"\"\"\n    # create a copy of image\n    image = np.copy(input_image)\n\n    if image.ndim == 2:\n        image = np.tile(image[:, :, None], (1, 1, 3))\n    elif image.ndim == 3:\n        if image.shape[2] == 1:\n            image = np.tile(image, (1, 1, 3))\n        else:\n            raise RuntimeError(f'if 3d image is given the last dimension must be the color channels (3 channels). '\n                               f'Only 2D images are supported. Your image shape: {image.shape}')\n    else:\n        raise RuntimeError(\"unexpected image shape. 
only 2D images and 2D images with color channels (color in \"\n                           \"last dimension) are supported\")\n\n    # rescale image to [0, 255]\n    image = image - image.min()\n    image = image / image.max() * 255\n\n    # create output\n    if mapping is None:\n        uniques = np.sort(pd.unique(segmentation.ravel()))  # np.unique(segmentation)\n        mapping = {i: c for c, i in enumerate(uniques)}\n\n    for l in mapping.keys():\n        image[segmentation == l] += overlay_intensity * np.array(hex_to_rgb(color_cycle[mapping[l]]))\n\n    # rescale result to [0, 255]\n    image = image / image.max() * 255\n    return image.astype(np.uint8)\n\n\ndef select_slice_to_plot(image: np.ndarray, segmentation: np.ndarray) -> int:\n    \"\"\"\n    image and segmentation are expected to be 3D\n\n    selects the slice with the largest amount of fg (regardless of label)\n\n    we give image so that we can easily replace this function if needed\n    \"\"\"\n    fg_mask = segmentation != 0\n    fg_per_slice = fg_mask.sum((1, 2))\n    selected_slice = int(np.argmax(fg_per_slice))\n    return selected_slice\n\n\ndef select_slice_to_plot2(image: np.ndarray, segmentation: np.ndarray) -> int:\n    \"\"\"\n    image and segmentation are expected to be 3D (or 1, x, y)\n\n    selects the slice with the largest amount of fg (how much percent of each class are in each slice? 
pick slice\n    with highest avg percent)\n\n    we give image so that we can easily replace this function if needed\n    \"\"\"\n    classes = [i for i in np.sort(pd.unique(segmentation.ravel())) if i > 0]\n    fg_per_slice = np.zeros((image.shape[0], len(classes)))\n    for i, c in enumerate(classes):\n        fg_mask = segmentation == c\n        fg_per_slice[:, i] = fg_mask.sum((1, 2))\n        # normalize by this class's total voxel count so each column holds per-slice fractions of that class\n        fg_per_slice[:, i] /= fg_per_slice[:, i].sum()\n    fg_per_slice = fg_per_slice.mean(1)\n    return int(np.argmax(fg_per_slice))\n\n\ndef plot_overlay(image_file: str, segmentation_file: str, image_reader_writer: BaseReaderWriter, output_file: str,\n                 overlay_intensity: float = 0.6):\n    import matplotlib.pyplot as plt\n\n    image, props = image_reader_writer.read_images((image_file, ))\n    image = image[0]\n    seg, props_seg = image_reader_writer.read_seg(segmentation_file)\n    seg = seg[0]\n\n    assert image.shape == seg.shape, \"image and seg do not have the same shape: %s, %s\" % (\n        image_file, segmentation_file)\n\n    assert image.ndim == 3, 'only 3D images/segs are supported'\n\n    selected_slice = select_slice_to_plot2(image, seg)\n    # print(image.shape, selected_slice)\n\n    overlay = generate_overlay(image[selected_slice], seg[selected_slice], overlay_intensity=overlay_intensity)\n\n    plt.imsave(output_file, overlay)\n\n\ndef plot_overlay_preprocessed(dataset: nnUNetBaseDataset, k: str, output_folder: str, overlay_intensity: float = 0.6, channel_idx=0):\n    import matplotlib.pyplot as plt\n    data, seg, _, properties = dataset.load_case(k)\n\n    assert channel_idx < (data.shape[0]), 'This dataset only supports channel index up to %d' % (data.shape[0] - 1)\n\n    image = data[channel_idx]\n    seg = seg[0]\n    selected_slice = select_slice_to_plot2(image, seg)\n\n    seg = np.copy(seg[selected_slice])\n    seg[seg < 0] = 0\n    overlay = generate_overlay(image[selected_slice], seg, overlay_intensity=overlay_intensity)\n\n    
plt.imsave(join(output_folder, k + '.png'), overlay)\n\n\ndef multiprocessing_plot_overlay(list_of_image_files, list_of_seg_files, image_reader_writer,\n                                 list_of_output_files, overlay_intensity,\n                                 num_processes=8):\n    with multiprocessing.get_context(\"spawn\").Pool(num_processes) as p:\n        r = p.starmap_async(plot_overlay, zip(\n            list_of_image_files, list_of_seg_files, [image_reader_writer] * len(list_of_output_files),\n            list_of_output_files, [overlay_intensity] * len(list_of_output_files)\n        ))\n        r.get()\n\n\ndef multiprocessing_plot_overlay_preprocessed(dataset: nnUNetBaseDataset, output_folder, overlay_intensity,\n                                              num_processes=8, channel_idx=0):\n    with multiprocessing.get_context(\"spawn\").Pool(num_processes) as p:\n        r = []\n        for k in dataset.identifiers:\n            r.append(\n                p.starmap_async(plot_overlay_preprocessed,\n                                ((\n                                    dataset, k, output_folder, overlay_intensity, channel_idx\n                                 ),))\n            )\n        _ = [i.get() for i in r]\n\n\ndef generate_overlays_from_raw(dataset_name_or_id: Union[int, str], output_folder: str,\n                               num_processes: int = 8, channel_idx: int = 0, overlay_intensity: float = 0.6):\n    dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)\n    folder = join(nnUNet_raw, dataset_name)\n    dataset_json = load_json(join(folder, 'dataset.json'))\n    dataset = get_filenames_of_train_images_and_targets(folder, dataset_json)\n\n    image_files = [v['images'][channel_idx] for v in dataset.values()]\n    seg_files = [v['label'] for v in dataset.values()]\n\n    assert all([isfile(i) for i in image_files])\n    assert all([isfile(i) for i in seg_files])\n\n    maybe_mkdir_p(output_folder)\n    output_files = 
[join(output_folder, i + '.png') for i in dataset.keys()]\n\n    image_reader_writer = determine_reader_writer_from_dataset_json(dataset_json, image_files[0])()\n    multiprocessing_plot_overlay(image_files, seg_files, image_reader_writer, output_files, overlay_intensity, num_processes)\n\n\ndef generate_overlays_from_preprocessed(dataset_name_or_id: Union[int, str], output_folder: str,\n                                        num_processes: int = 8, channel_idx: int = 0,\n                                        configuration: str = None,\n                                        plans_identifier: str = 'nnUNetPlans',\n                                        overlay_intensity: float = 0.6):\n    dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)\n    folder = join(nnUNet_preprocessed, dataset_name)\n    if not isdir(folder): raise RuntimeError(\"run preprocessing for that task first\")\n\n    plans = load_json(join(folder, plans_identifier + '.json'))\n    if configuration is None:\n        if '3d_fullres' in plans['configurations'].keys():\n            configuration = '3d_fullres'\n        else:\n            configuration = '2d'\n    cm = ConfigurationManager(plans['configurations'][configuration])\n    preprocessed_folder = join(folder, cm.data_identifier)\n\n    if not isdir(preprocessed_folder):\n        raise RuntimeError(f\"Preprocessed data folder for configuration {configuration} of plans identifier \"\n                           f\"{plans_identifier} ({dataset_name}) does not exist. 
Run preprocessing for this \"\n                           f\"configuration first!\")\n\n    dc = infer_dataset_class(preprocessed_folder)\n    dataset = dc(preprocessed_folder)\n\n    maybe_mkdir_p(output_folder)\n    multiprocessing_plot_overlay_preprocessed(dataset, output_folder, overlay_intensity=overlay_intensity,\n                                              num_processes=num_processes, channel_idx=channel_idx)\n\n\ndef entry_point_generate_overlay():\n    import argparse\n    parser = argparse.ArgumentParser(\"Plots png overlays of the slice with the most foreground. Note that this \"\n                                     \"disregards spacing information!\")\n    parser.add_argument('-d', type=str, help=\"Dataset name or id\", required=True)\n    parser.add_argument('-o', type=str, help=\"output folder\", required=True)\n    parser.add_argument('-np', type=int, default=default_num_processes, required=False,\n                        help=f\"number of processes used. Default: {default_num_processes}\")\n    parser.add_argument('-channel_idx', type=int, default=0, required=False,\n                        help=\"channel index used (0 = _0000). Default: 0\")\n    parser.add_argument('--use_raw', action='store_true', required=False, help=\"if set then we use raw data. else \"\n                                                                               \"we use preprocessed\")\n    parser.add_argument('-p', type=str, required=False, default='nnUNetPlans',\n                        help='plans identifier. Only used if --use_raw is not set! Default: nnUNetPlans')\n    parser.add_argument('-c', type=str, required=False, default=None,\n                        help='configuration name. Only used if --use_raw is not set! Default: None = '\n                             '3d_fullres if available, else 2d')\n    parser.add_argument('-overlay_intensity', type=float, required=False, default=0.6,\n                        help='overlay intensity. 
Higher = brighter/less transparent')\n\n\n    args = parser.parse_args()\n\n    if args.use_raw:\n        generate_overlays_from_raw(args.d, args.o, args.np, args.channel_idx,\n                                   overlay_intensity=args.overlay_intensity)\n    else:\n        generate_overlays_from_preprocessed(args.d, args.o, args.np, args.channel_idx, args.c, args.p,\n                                            overlay_intensity=args.overlay_intensity)\n\n\nif __name__ == '__main__':\n    entry_point_generate_overlay()\n"
  },
  {
    "path": "nnunetv2/utilities/plans_handling/__init__.py",
    "content": ""
  },
  {
    "path": "nnunetv2/utilities/plans_handling/plans_handler.py",
    "content": "from __future__ import annotations\n\nimport warnings\n\nfrom copy import deepcopy\nfrom functools import lru_cache, partial\nfrom typing import Union, Tuple, List, Type, Callable\n\nimport numpy as np\nimport torch\n\nfrom nnunetv2.preprocessing.resampling.utils import recursive_find_resampling_fn_by_name\nimport nnunetv2\nfrom batchgenerators.utilities.file_and_folder_operations import load_json, join\n\nfrom nnunetv2.imageio.reader_writer_registry import recursive_find_reader_writer_by_name\nfrom nnunetv2.utilities.find_class_by_name import recursive_find_python_class\nfrom nnunetv2.utilities.label_handling.label_handling import get_labelmanager_class_from_plans\n\n# see https://adamj.eu/tech/2021/05/13/python-type-hints-how-to-fix-circular-imports/\nfrom typing import TYPE_CHECKING\nfrom dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm\n\nif TYPE_CHECKING:\n    from nnunetv2.utilities.label_handling.label_handling import LabelManager\n    from nnunetv2.imageio.base_reader_writer import BaseReaderWriter\n    from nnunetv2.preprocessing.preprocessors.default_preprocessor import DefaultPreprocessor\n    from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner\n\n\nclass ConfigurationManager(object):\n    def __init__(self, configuration_dict: dict):\n        self.configuration = configuration_dict\n\n        # backwards compatibility\n        if 'architecture' not in self.configuration.keys():\n            warnings.warn(\"Detected old nnU-Net plans format. Attempting to reconstruct network architecture \"\n                          \"parameters. If this fails, rerun nnUNetv2_plan_experiment for your dataset. 
If you use a \"\n                          \"custom architecture, please downgrade nnU-Net to the version you implemented this \"\n                          \"or update your implementation + plans.\")\n            # try to build the architecture information from old plans, modify configuration dict to match new standard\n            unet_class_name = self.configuration[\"UNet_class_name\"]\n            if unet_class_name == \"PlainConvUNet\":\n                network_class_name = \"dynamic_network_architectures.architectures.unet.PlainConvUNet\"\n            elif unet_class_name == 'ResidualEncoderUNet':\n                network_class_name = \"dynamic_network_architectures.architectures.residual_unet.ResidualEncoderUNet\"\n            else:\n                raise RuntimeError(f'Unknown architecture {unet_class_name}. This conversion only supports '\n                                   f'PlainConvUNet and ResidualEncoderUNet')\n\n            n_stages = len(self.configuration[\"n_conv_per_stage_encoder\"])\n\n            dim = len(self.configuration[\"patch_size\"])\n            conv_op = convert_dim_to_conv_op(dim)\n            instnorm = get_matching_instancenorm(dimension=dim)\n\n            convs_or_blocks = \"n_conv_per_stage\" if unet_class_name == \"PlainConvUNet\" else \"n_blocks_per_stage\"\n\n            arch_dict = {\n                'network_class_name': network_class_name,\n                'arch_kwargs': {\n                    \"n_stages\": n_stages,\n                    \"features_per_stage\": [min(self.configuration[\"UNet_base_num_features\"] * 2 ** i,\n                                               self.configuration[\"unet_max_num_features\"])\n                                           for i in range(n_stages)],\n                    \"conv_op\": conv_op.__module__ + '.' 
+ conv_op.__name__,\n                    \"kernel_sizes\": deepcopy(self.configuration[\"conv_kernel_sizes\"]),\n                    \"strides\": deepcopy(self.configuration[\"pool_op_kernel_sizes\"]),\n                    convs_or_blocks: deepcopy(self.configuration[\"n_conv_per_stage_encoder\"]),\n                    \"n_conv_per_stage_decoder\": deepcopy(self.configuration[\"n_conv_per_stage_decoder\"]),\n                    \"conv_bias\": True,\n                    \"norm_op\": instnorm.__module__ + '.' + instnorm.__name__,\n                    \"norm_op_kwargs\": {\n                        \"eps\": 1e-05,\n                        \"affine\": True\n                    },\n                    \"dropout_op\": None,\n                    \"dropout_op_kwargs\": None,\n                    \"nonlin\": \"torch.nn.LeakyReLU\",\n                    \"nonlin_kwargs\": {\n                        \"inplace\": True\n                    }\n                },\n                # these need to be imported with locate in order to use them:\n                # `conv_op = pydoc.locate(architecture_kwargs['conv_op'])`\n                \"_kw_requires_import\": [\n                    \"conv_op\",\n                    \"norm_op\",\n                    \"dropout_op\",\n                    \"nonlin\"\n                ]\n            }\n            del self.configuration[\"UNet_class_name\"], self.configuration[\"UNet_base_num_features\"], \\\n                self.configuration[\"n_conv_per_stage_encoder\"], self.configuration[\"n_conv_per_stage_decoder\"], \\\n                self.configuration[\"num_pool_per_axis\"], self.configuration[\"pool_op_kernel_sizes\"],\\\n                self.configuration[\"conv_kernel_sizes\"], self.configuration[\"unet_max_num_features\"]\n            self.configuration[\"architecture\"] = arch_dict\n\n    def __repr__(self):\n        return self.configuration.__repr__()\n\n    @property\n    def data_identifier(self) -> str:\n        return 
self.configuration['data_identifier']\n\n    @property\n    def preprocessor_name(self) -> str:\n        return self.configuration['preprocessor_name']\n\n    @property\n    @lru_cache(maxsize=1)\n    def preprocessor_class(self) -> Type[DefaultPreprocessor]:\n        preprocessor_class = recursive_find_python_class(join(nnunetv2.__path__[0], \"preprocessing\"),\n                                                         self.preprocessor_name,\n                                                         current_module=\"nnunetv2.preprocessing\")\n        return preprocessor_class\n\n    @property\n    def batch_size(self) -> int:\n        return self.configuration['batch_size']\n\n    @property\n    def patch_size(self) -> List[int]:\n        return self.configuration['patch_size']\n\n    @property\n    def median_image_size_in_voxels(self) -> List[int]:\n        return self.configuration['median_image_size_in_voxels']\n\n    @property\n    def spacing(self) -> List[float]:\n        return self.configuration['spacing']\n\n    @property\n    def normalization_schemes(self) -> List[str]:\n        return self.configuration['normalization_schemes']\n\n    @property\n    def use_mask_for_norm(self) -> List[bool]:\n        return self.configuration['use_mask_for_norm']\n\n    @property\n    def network_arch_class_name(self) -> str:\n        return self.configuration['architecture']['network_class_name']\n\n    @property\n    def network_arch_init_kwargs(self) -> dict:\n        return self.configuration['architecture']['arch_kwargs']\n\n    @property\n    def network_arch_init_kwargs_req_import(self) -> Union[Tuple[str, ...], List[str]]:\n        return self.configuration['architecture']['_kw_requires_import']\n\n    @property\n    def pool_op_kernel_sizes(self) -> Tuple[Tuple[int, ...], ...]:\n        return self.configuration['architecture']['arch_kwargs']['strides']\n\n    @property\n    @lru_cache(maxsize=1)\n    def resampling_fn_data(self) -> Callable[\n        
[Union[torch.Tensor, np.ndarray],\n         Union[Tuple[int, ...], List[int], np.ndarray],\n         Union[Tuple[float, ...], List[float], np.ndarray],\n         Union[Tuple[float, ...], List[float], np.ndarray]\n         ],\n        Union[torch.Tensor, np.ndarray]]:\n        fn = recursive_find_resampling_fn_by_name(self.configuration['resampling_fn_data'])\n        fn = partial(fn, **self.configuration['resampling_fn_data_kwargs'])\n        return fn\n\n    @property\n    @lru_cache(maxsize=1)\n    def resampling_fn_probabilities(self) -> Callable[\n        [Union[torch.Tensor, np.ndarray],\n         Union[Tuple[int, ...], List[int], np.ndarray],\n         Union[Tuple[float, ...], List[float], np.ndarray],\n         Union[Tuple[float, ...], List[float], np.ndarray]\n         ],\n        Union[torch.Tensor, np.ndarray]]:\n        fn = recursive_find_resampling_fn_by_name(self.configuration['resampling_fn_probabilities'])\n        fn = partial(fn, **self.configuration['resampling_fn_probabilities_kwargs'])\n        return fn\n\n    @property\n    @lru_cache(maxsize=1)\n    def resampling_fn_seg(self) -> Callable[\n        [Union[torch.Tensor, np.ndarray],\n         Union[Tuple[int, ...], List[int], np.ndarray],\n         Union[Tuple[float, ...], List[float], np.ndarray],\n         Union[Tuple[float, ...], List[float], np.ndarray]\n         ],\n        Union[torch.Tensor, np.ndarray]]:\n        fn = recursive_find_resampling_fn_by_name(self.configuration['resampling_fn_seg'])\n        fn = partial(fn, **self.configuration['resampling_fn_seg_kwargs'])\n        return fn\n\n    @property\n    def batch_dice(self) -> bool:\n        return self.configuration['batch_dice']\n\n    @property\n    def next_stage_names(self) -> Union[List[str], None]:\n        ret = self.configuration.get('next_stage')\n        if ret is not None:\n            if isinstance(ret, str):\n                ret = [ret]\n        return ret\n\n    @property\n    def previous_stage_name(self) -> 
Union[str, None]:\n        return self.configuration.get('previous_stage')\n\n\nclass PlansManager(object):\n    def __init__(self, plans_file_or_dict: Union[str, dict]):\n        \"\"\"\n        Why do we need this?\n        1) resolve inheritance in configurations\n        2) expose otherwise annoying stuff like getting the label manager or IO class from a string\n        3) clearly expose the things that are in the plans instead of hiding them in a dict\n        4) cache results that are expensive to look up\n\n        This class does not prevent you from going wild. You can still use the plans directly if you prefer\n        (PlansManager.plans['key'])\n        \"\"\"\n        self.plans = plans_file_or_dict if isinstance(plans_file_or_dict, dict) else load_json(plans_file_or_dict)\n\n    def __repr__(self):\n        return self.plans.__repr__()\n\n    def _internal_resolve_configuration_inheritance(self, configuration_name: str,\n                                                    visited: Tuple[str, ...] = None) -> dict:\n        if configuration_name not in self.plans['configurations'].keys():\n            raise ValueError(f'The configuration {configuration_name} does not exist in the plans I have. Valid '\n                             f'configuration names are {list(self.plans[\"configurations\"].keys())}.')\n        configuration = deepcopy(self.plans['configurations'][configuration_name])\n        if 'inherits_from' in configuration:\n            parent_config_name = configuration['inherits_from']\n\n            if visited is None:\n                visited = (configuration_name,)\n            else:\n                if parent_config_name in visited:\n                    raise RuntimeError(f\"Circular dependency detected. The following configurations were visited \"\n                                       f\"while solving inheritance (in that order!): {visited}. \"\n                                       f\"Current configuration: {configuration_name}. 
Its parent configuration \"\n                                       f\"is {parent_config_name}.\")\n                visited = (*visited, configuration_name)\n\n            base_config = self._internal_resolve_configuration_inheritance(parent_config_name, visited)\n            base_config.update(configuration)\n            configuration = base_config\n        return configuration\n\n    @lru_cache(maxsize=10)\n    def get_configuration(self, configuration_name: str):\n        if configuration_name not in self.plans['configurations'].keys():\n            raise RuntimeError(f\"Requested configuration {configuration_name} not found in plans. \"\n                               f\"Available configurations: {list(self.plans['configurations'].keys())}\")\n\n        configuration_dict = self._internal_resolve_configuration_inheritance(configuration_name)\n        return ConfigurationManager(configuration_dict)\n\n    @property\n    def dataset_name(self) -> str:\n        return self.plans['dataset_name']\n\n    @property\n    def plans_name(self) -> str:\n        return self.plans['plans_name']\n\n    @property\n    def original_median_spacing_after_transp(self) -> List[float]:\n        return self.plans['original_median_spacing_after_transp']\n\n    @property\n    def original_median_shape_after_transp(self) -> List[float]:\n        return self.plans['original_median_shape_after_transp']\n\n    @property\n    @lru_cache(maxsize=1)\n    def image_reader_writer_class(self) -> Type[BaseReaderWriter]:\n        return recursive_find_reader_writer_by_name(self.plans['image_reader_writer'])\n\n    @property\n    def transpose_forward(self) -> List[int]:\n        return self.plans['transpose_forward']\n\n    @property\n    def transpose_backward(self) -> List[int]:\n        return self.plans['transpose_backward']\n\n    @property\n    def available_configurations(self) -> List[str]:\n        return list(self.plans['configurations'].keys())\n\n    @property\n    
@lru_cache(maxsize=1)\n    def experiment_planner_class(self) -> Type[ExperimentPlanner]:\n        planner_name = self.experiment_planner_name\n        experiment_planner = recursive_find_python_class(join(nnunetv2.__path__[0], \"experiment_planning\"),\n                                                         planner_name,\n                                                         current_module=\"nnunetv2.experiment_planning\")\n        return experiment_planner\n\n    @property\n    def experiment_planner_name(self) -> str:\n        return self.plans['experiment_planner_used']\n\n    @property\n    @lru_cache(maxsize=1)\n    def label_manager_class(self) -> Type[LabelManager]:\n        return get_labelmanager_class_from_plans(self.plans)\n\n    def get_label_manager(self, dataset_json: dict, **kwargs) -> LabelManager:\n        return self.label_manager_class(label_dict=dataset_json['labels'],\n                                        regions_class_order=dataset_json.get('regions_class_order'),\n                                        **kwargs)\n\n    @property\n    def foreground_intensity_properties_per_channel(self) -> dict:\n        if 'foreground_intensity_properties_per_channel' not in self.plans.keys():\n            if 'foreground_intensity_properties_by_modality' in self.plans.keys():\n                return self.plans['foreground_intensity_properties_by_modality']\n        return self.plans['foreground_intensity_properties_per_channel']\n\n\nif __name__ == '__main__':\n    from nnunetv2.paths import nnUNet_preprocessed\n    from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name\n\n    plans = load_json(join(nnUNet_preprocessed, maybe_convert_to_dataset_name(3), 'nnUNetPlans.json'))\n    # build new configuration that inherits from 3d_fullres\n    plans['configurations']['3d_fullres_bs4'] = {\n        'batch_size': 4,\n        'inherits_from': '3d_fullres'\n    }\n    # now get plans and configuration managers\n    
plans_manager = PlansManager(plans)\n    configuration_manager = plans_manager.get_configuration('3d_fullres_bs4')\n    print(configuration_manager)  # look for batch size 4\n"
  },
  {
    "path": "nnunetv2/utilities/utils.py",
    "content": "#    Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center\n#    (DKFZ), Heidelberg, Germany\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not use this file except in compliance with the License.\n#    You may obtain a copy of the License at\n#\n#        http://www.apache.org/licenses/LICENSE-2.0\n#\n#    Unless required by applicable law or agreed to in writing, software\n#    distributed under the License is distributed on an \"AS IS\" BASIS,\n#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#    See the License for the specific language governing permissions and\n#    limitations under the License.\nimport os.path\nfrom functools import lru_cache\nfrom typing import Union\n\nfrom batchgenerators.utilities.file_and_folder_operations import *\nimport numpy as np\nimport re\n\nfrom nnunetv2.paths import nnUNet_raw\nfrom multiprocessing import Pool\n\n\ndef get_identifiers_from_splitted_dataset_folder(folder: str, file_ending: str):\n    files = subfiles(folder, suffix=file_ending, join=False)\n    # all files have a 4 digit channel index (_XXXX)\n    crop = len(file_ending) + 5\n    files = [i[:-crop] for i in files]\n    # only unique image ids\n    files = np.unique(files)\n    return files\n\n\ndef create_paths_fn(folder, files, file_ending, f):\n    p = re.compile(re.escape(f) + r\"_\\d\\d\\d\\d\" + re.escape(file_ending))            \n    return [join(folder, i) for i in files if p.fullmatch(i)]\n\n\ndef create_lists_from_splitted_dataset_folder(folder: str, file_ending: str, identifiers: List[str] = None, num_processes: int = 12) -> List[\n    List[str]]:\n    \"\"\"\n    does not rely on dataset.json\n    \"\"\"\n    if identifiers is None:\n        identifiers = get_identifiers_from_splitted_dataset_folder(folder, file_ending)\n    files = subfiles(folder, suffix=file_ending, join=False, sort=True)\n    
list_of_lists = []\n\n    params_list = [(folder, files, file_ending, f) for f in identifiers]\n    with Pool(processes=num_processes) as pool:\n        list_of_lists = pool.starmap(create_paths_fn, params_list)\n        \n    return list_of_lists\n\n\ndef get_filenames_of_train_images_and_targets(raw_dataset_folder: str, dataset_json: dict = None):\n    if dataset_json is None:\n        dataset_json = load_json(join(raw_dataset_folder, 'dataset.json'))\n\n    if 'dataset' in dataset_json.keys():\n        dataset = dataset_json['dataset']\n        for k in dataset.keys():\n            expanded_label_file = os.path.expandvars(dataset[k]['label'])\n            dataset[k]['label'] = os.path.abspath(join(raw_dataset_folder, expanded_label_file)) if not os.path.isabs(expanded_label_file) else expanded_label_file\n            dataset[k]['images'] = [os.path.abspath(join(raw_dataset_folder, os.path.expandvars(i))) if not os.path.isabs(os.path.expandvars(i)) else os.path.expandvars(i) for i in dataset[k]['images']]\n    else:\n        identifiers = get_identifiers_from_splitted_dataset_folder(join(raw_dataset_folder, 'imagesTr'), dataset_json['file_ending'])\n        images = create_lists_from_splitted_dataset_folder(join(raw_dataset_folder, 'imagesTr'), dataset_json['file_ending'], identifiers)\n        segs = [join(raw_dataset_folder, 'labelsTr', i + dataset_json['file_ending']) for i in identifiers]\n        dataset = {i: {'images': im, 'label': se} for i, im, se in zip(identifiers, images, segs)}\n    return dataset\n\n\nif __name__ == '__main__':\n    print(get_filenames_of_train_images_and_targets(join(nnUNet_raw, 'Dataset002_Heart')))\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[project]\nname = \"nnunetv2\"\nversion = \"2.6.4\"\nrequires-python = \">=3.10\"\ndescription = \"nnU-Net is a framework for out-of-the box image segmentation.\"\nreadme = \"readme.md\"\nlicense = { file = \"LICENSE\" }\nauthors = [\n    { name = \"Fabian Isensee\", email = \"f.isensee@dkfz-heidelberg.de\"},\n    { name = \"Helmholtz Imaging Applied Computer Vision Lab\" }\n]\nclassifiers = [\n    \"Development Status :: 5 - Production/Stable\",\n    \"Intended Audience :: Developers\",\n    \"Intended Audience :: Science/Research\",\n    \"Intended Audience :: Healthcare Industry\",\n    \"Programming Language :: Python :: 3\",\n    \"License :: OSI Approved :: Apache Software License\",\n    \"Topic :: Scientific/Engineering :: Artificial Intelligence\",\n    \"Topic :: Scientific/Engineering :: Image Recognition\",\n    \"Topic :: Scientific/Engineering :: Medical Science Apps.\",\n]\nkeywords = [\n    'deep learning',\n    'image segmentation',\n    'semantic segmentation',\n    'medical image analysis',\n    'medical image segmentation',\n    'nnU-Net',\n    'nnunet'\n]\ndependencies = [\n    \"torch>=2.1.2,!=2.9.*\",\n    \"acvl-utils>=0.2.3,<0.3\",  # 0.3 may bring breaking changes. 
Careful!\n    \"dynamic-network-architectures>=0.4.1,<0.5\",\n    \"tqdm\",\n    \"scipy\",\n    \"batchgenerators>=0.25.1\",\n    \"numpy>=1.24\",\n    \"scikit-learn\",\n    \"scikit-image>=0.19.3\",\n    \"SimpleITK>=2.2.1\",\n    \"pandas\",\n    \"graphviz\",\n    'tifffile',\n    'requests',\n    \"nibabel\",\n    \"matplotlib\",\n    \"seaborn\",\n    \"imagecodecs\",\n    \"yacs\",\n    \"batchgeneratorsv2>=0.3.0\",\n    \"einops\",\n    \"blosc2>=3.0.0b1\"\n]\n\n[project.urls]\nhomepage = \"https://github.com/MIC-DKFZ/nnUNet\"\nrepository = \"https://github.com/MIC-DKFZ/nnUNet\"\n\n[project.scripts]\nnnUNetv2_plan_and_preprocess = \"nnunetv2.experiment_planning.plan_and_preprocess_entrypoints:plan_and_preprocess_entry\"\nnnUNetv2_extract_fingerprint = \"nnunetv2.experiment_planning.plan_and_preprocess_entrypoints:extract_fingerprint_entry\"\nnnUNetv2_plan_experiment = \"nnunetv2.experiment_planning.plan_and_preprocess_entrypoints:plan_experiment_entry\"\nnnUNetv2_preprocess = \"nnunetv2.experiment_planning.plan_and_preprocess_entrypoints:preprocess_entry\"\nnnUNetv2_train = \"nnunetv2.run.run_training:run_training_entry\"\nnnUNetv2_predict_from_modelfolder = \"nnunetv2.inference.predict_from_raw_data:predict_entry_point_modelfolder\"\nnnUNetv2_predict = \"nnunetv2.inference.predict_from_raw_data:predict_entry_point\"\nnnUNetv2_convert_old_nnUNet_dataset = \"nnunetv2.dataset_conversion.convert_raw_dataset_from_old_nnunet_format:convert_entry_point\"\nnnUNetv2_find_best_configuration = \"nnunetv2.evaluation.find_best_configuration:find_best_configuration_entry_point\"\nnnUNetv2_determine_postprocessing = \"nnunetv2.postprocessing.remove_connected_components:entry_point_determine_postprocessing_folder\"\nnnUNetv2_apply_postprocessing = \"nnunetv2.postprocessing.remove_connected_components:entry_point_apply_postprocessing\"\nnnUNetv2_ensemble = \"nnunetv2.ensembling.ensemble:entry_point_ensemble_folders\"\nnnUNetv2_accumulate_crossval_results = 
\"nnunetv2.evaluation.find_best_configuration:accumulate_crossval_results_entry_point\"\nnnUNetv2_plot_overlay_pngs = \"nnunetv2.utilities.overlay_plots:entry_point_generate_overlay\"\nnnUNetv2_download_pretrained_model_by_url = \"nnunetv2.model_sharing.entry_points:download_by_url\"\nnnUNetv2_install_pretrained_model_from_zip = \"nnunetv2.model_sharing.entry_points:install_from_zip_entry_point\"\nnnUNetv2_export_model_to_zip = \"nnunetv2.model_sharing.entry_points:export_pretrained_model_entry\"\nnnUNetv2_move_plans_between_datasets = \"nnunetv2.experiment_planning.plans_for_pretraining.move_plans_between_datasets:entry_point_move_plans_between_datasets\"\nnnUNetv2_evaluate_folder = \"nnunetv2.evaluation.evaluate_predictions:evaluate_folder_entry_point\"\nnnUNetv2_evaluate_simple = \"nnunetv2.evaluation.evaluate_predictions:evaluate_simple_entry_point\"\nnnUNetv2_convert_MSD_dataset = \"nnunetv2.dataset_conversion.convert_MSD_dataset:entry_point\"\n\n[project.optional-dependencies]\ndev = [\n    \"black\",\n    \"ruff\",\n    \"pre-commit\"\n]\n\n[build-system]\nrequires = [\"setuptools>=67.8.0\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[tool.codespell]\nskip = '.git,*.pdf,*.svg'\n#\n# ignore-words-list = ''\n"
  },
  {
    "path": "readme.md",
    "content": "# Welcome to the new nnU-Net!\n\nClick [here](https://github.com/MIC-DKFZ/nnUNet/tree/nnunetv1) if you were looking for the old one instead.\n\nComing from V1? Check out the [TLDR Migration Guide](documentation/tldr_migration_guide_from_v1.md). Reading the rest of the documentation is still strongly recommended ;-)\n\n## **2025-10-23 There seems to be a [severe performance regression with torch 2.9.0 and 3D convs](https://github.com/pytorch/pytorch/issues/166122) (when using AMP). Please use torch 2.8.0 or lower with nnU-Net!**\n\n\n## **2024-04-18 UPDATE: New residual encoder UNet presets available!**\nResidual encoder UNet presets substantially improve segmentation performance.\nThey ship for a variety of GPU memory targets. It's all awesome stuff, promised! \nRead more :point_right: [here](documentation/resenc_presets.md) :point_left:\n\nAlso check out our [new paper](https://arxiv.org/pdf/2404.09556.pdf) on systematically benchmarking recent developments in medical image segmentation. You might be surprised!\n\n# What is nnU-Net?\nImage datasets are enormously diverse: image dimensionality (2D, 3D), modalities/input channels (RGB image, CT, MRI, microscopy, ...), \nimage sizes, voxel sizes, class ratio, target structure properties and more change substantially between datasets. \nTraditionally, given a new problem, a tailored solution needs to be manually designed and optimized  - a process that \nis prone to errors, not scalable and where success is overwhelmingly determined by the skill of the experimenter. Even \nfor experts, this process is anything but simple: there are not only many design choices and data properties that need to \nbe considered, but they are also tightly interconnected, rendering reliable manual pipeline optimization all but impossible! \n\n![nnU-Net overview](documentation/assets/nnU-Net_overview.png)\n\n**nnU-Net is a semantic segmentation method that automatically adapts to a given dataset. 
It will analyze the provided \ntraining cases and automatically configure a matching U-Net-based segmentation pipeline. No expertise required on your \nend! You can simply train the models and use them for your application**.\n\nUpon release, nnU-Net was evaluated on 23 datasets belonging to competitions from the biomedical domain. Despite competing \nwith handcrafted solutions for each respective dataset, nnU-Net's fully automated pipeline scored several first places on \nopen leaderboards! Since then nnU-Net has stood the test of time: it continues to be used as a baseline and method \ndevelopment framework ([9 out of 10 challenge winners at MICCAI 2020](https://arxiv.org/abs/2101.00232) and 5 out of 7 \nin MICCAI 2021 built their methods on top of nnU-Net, \n [we won AMOS2022 with nnU-Net](https://amos22.grand-challenge.org/final-ranking/))!\n\nPlease cite the [following paper](https://www.nature.com/articles/s41592-020-01008-z) when using nnU-Net:\n\n    Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). nnU-Net: a self-configuring \n    method for deep learning-based biomedical image segmentation. Nature methods, 18(2), 203-211.\n\n\n## What can nnU-Net do for you?\nIf you are a **domain scientist** (biologist, radiologist, ...) looking to analyze your own images, nnU-Net provides \nan out-of-the-box solution that is all but guaranteed to provide excellent results on your individual dataset. 
Simply \nconvert your dataset into the nnU-Net format and enjoy the power of AI - no expertise required!\n\nIf you are an **AI researcher** developing segmentation methods, nnU-Net:\n- offers a fantastic out-of-the-box applicable baseline algorithm to compete against\n- can act as a method development framework to test your contribution on a large number of datasets without having to \ntune individual pipelines (for example evaluating a new loss function)\n- provides a strong starting point for further dataset-specific optimizations. This is particularly useful when competing \nin segmentation challenges\n- provides a new perspective on the design of segmentation methods: maybe you can find better connections between \ndataset properties and best-fitting segmentation pipelines?\n\n## What is the scope of nnU-Net?\nnnU-Net is built for semantic segmentation. It can handle 2D and 3D images with arbitrary \ninput modalities/channels. It understands voxel spacings and anisotropies and is robust even when classes are highly\nimbalanced.\n\nnnU-Net relies on supervised learning, which means that you need to provide training cases for your application. The number of \nrequired training cases varies heavily depending on the complexity of the segmentation problem. No \none-size-fits-all number can be provided here! nnU-Net does not require more training cases than other solutions - maybe \neven fewer due to our extensive use of data augmentation. \n\nnnU-Net expects to be able to process entire images at once during preprocessing and postprocessing, so it cannot \nhandle enormous images. As a reference: we tested images from 40x40x40 pixels all the way up to 1500x1500x1500 in 3D \nand 40x40 up to ~30000x30000 in 2D! If your RAM allows it, larger is always possible.\n\n## How does nnU-Net work?\nGiven a new dataset, nnU-Net will systematically analyze the provided training cases and create a 'dataset fingerprint'. 
\nnnU-Net then creates several U-Net configurations for each dataset: \n- `2d`: a 2D U-Net (for 2D and 3D datasets)\n- `3d_fullres`: a 3D U-Net that operates on a high image resolution (for 3D datasets only)\n- `3d_lowres` → `3d_cascade_fullres`: a 3D U-Net cascade where first a 3D U-Net operates on low resolution images and \nthen a second high-resolution 3D U-Net refines the predictions of the former (for 3D datasets with large image sizes only)\n\n**Note that not all U-Net configurations are created for all datasets. In datasets with small image sizes, the \nU-Net cascade (and with it the 3d_lowres configuration) is omitted because the patch size of the full \nresolution U-Net already covers a large part of the input images.**\n\nnnU-Net configures its segmentation pipelines based on a three-step recipe:\n- **Fixed parameters** are not adapted. During development of nnU-Net we identified a robust configuration (that is, certain architecture and training properties) that can \nsimply be used all the time. This includes, for example, nnU-Net's loss function, (most of the) data augmentation strategy and learning rate.\n- **Rule-based parameters** use the dataset fingerprint to adapt certain segmentation pipeline properties by following \nhard-coded heuristic rules. For example, the network topology (pooling behavior and depth of the network architecture) \nis adapted to the patch size; the patch size, network topology and batch size are optimized jointly given some GPU \nmemory constraint. \n- **Empirical parameters** are essentially trial-and-error. 
For example, the selection of the best U-Net configuration \nfor the given dataset (2D, 3D full resolution, 3D low resolution, 3D cascade) and the optimization of the postprocessing strategy.\n\n## How to get started?\nRead these:\n- [Installation instructions](documentation/installation_instructions.md)\n- [Dataset conversion](documentation/dataset_format.md)\n- [Usage instructions](documentation/how_to_use_nnunet.md)\n\nAdditional information:\n- [Contributing to nnU-Net](CONTRIBUTING.md)\n- [Learning from sparse annotations (scribbles, slices)](documentation/ignore_label.md)\n- [Region-based training](documentation/region_based_training.md)\n- [Manual data splits](documentation/manual_data_splits.md)\n- [Pretraining and finetuning](documentation/pretraining_and_finetuning.md)\n- [Intensity Normalization in nnU-Net](documentation/explanation_normalization.md)\n- [Training logging (Local + Weights & Biases)](documentation/explanation_logging.md)\n- [Manually editing nnU-Net configurations](documentation/explanation_plans_files.md)\n- [Extending nnU-Net](documentation/extending_nnunet.md)\n- [What is different in V2?](documentation/changelog.md)\n\nCompetitions:\n- [AutoPET II](documentation/competitions/AutoPETII.md)\n\n[//]: # (- [Ignore label]&#40;documentation/ignore_label.md&#41;)\n\n## Where does nnU-Net perform well and where does it not perform?\nnnU-Net excels in segmentation problems that need to be solved by training from scratch, \nfor example: research applications that feature non-standard image modalities and input channels,\nchallenge datasets from the biomedical domain, the majority of 3D segmentation problems, etc. We have yet to find a \ndataset for which nnU-Net's working principle fails!\n\nNote: On standard segmentation \nproblems, such as 2D RGB images in ADE20k and Cityscapes, fine-tuning a foundation model (that was pretrained on a large corpus of \nsimilar images, e.g. ImageNet 22k, JFT-300M) will provide better performance than nnU-Net! 
That is simply because these \nmodels allow much better initialization. Foundation models are not supported by nnU-Net as \nthey 1) are not useful for segmentation problems that deviate from the standard setting (see above mentioned \ndatasets), 2) would typically only support 2D architectures and 3) conflict with our core design principle of carefully adapting \nthe network topology for each dataset (if the topology is changed one can no longer transfer pretrained weights!) \n\n## What happened to the old nnU-Net?\nThe core of the old nnU-Net was hacked together in a short time period while participating in the Medical Segmentation \nDecathlon challenge in 2018. Consequently, code structure and quality were not the best. Many features \nwere added later on and didn't quite fit into the nnU-Net design principles. Overall quite messy, really. And annoying to work with.\n\nnnU-Net V2 is a complete overhaul. The \"delete everything and start again\" kind. So everything is better \n(in the author's opinion haha). While the segmentation performance [remains the same](https://docs.google.com/spreadsheets/d/13gqjIKEMPFPyMMMwA1EML57IyoBjfC3-QCTn4zRN_Mg/edit?usp=sharing), a lot of cool stuff has been added. \nIt is now also much easier to use it as a development framework and to manually fine-tune its configuration to new \ndatasets. A big driver for the reimplementation was also the emergence of [Helmholtz Imaging](http://helmholtz-imaging.de), \nprompting us to extend nnU-Net to more image formats and domains. 
Take a look [here](documentation/changelog.md) for some highlights.\n\n# Acknowledgements\n<img src=\"documentation/assets/HI_Logo.png\" height=\"100px\" />\n\n<img src=\"documentation/assets/dkfz_logo.png\" height=\"100px\" />\n\nnnU-Net is developed and maintained by the Applied Computer Vision Lab (ACVL) of [Helmholtz Imaging](http://helmholtz-imaging.de) \nand the [Division of Medical Image Computing](https://www.dkfz.de/en/mic/index.php) at the \n[German Cancer Research Center (DKFZ)](https://www.dkfz.de/en/index.html).\n"
  },
  {
    "path": "setup.py",
    "content": "import setuptools\n\nif __name__ == \"__main__\":\n    setuptools.setup()\n"
  }
]